llama : add llama_beam_search() (#2267)
* Add llama_beam_search().
* Add '// Beam search' heading to llama.{h,cpp} after llama_grammar_accept_token().
* Add space around * pointers and & references.
* Add spaces around comparison and assignment operators.
* Prefer west const.
* Use llama_ prefix for structs in global namespace.
* Delete obsolete comment from an earlier revision.
* Change eos to eob in llama_beam and llama_beam_view structs.
parent 28b2c996ca
commit c82742ac9c
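At a glance, the new API is callback-driven: the caller supplies a llama_beam_search_callback_fn_t plus an opaque callback_data pointer, and copies out converged tokens each time the callback fires. A condensed sketch of the calling pattern, distilled from the full examples/beam_search/beam_search.cpp added below (setup and error handling omitted):

    // Collect converged tokens as the beams advance.
    struct beam_search_callback_data {
        llama_context * ctx;
        std::vector<llama_token> response;
    };

    void beam_search_callback(void * data_ptr, llama_beams_state beams_state) {
        auto & data = *static_cast<beam_search_callback_data *>(data_ptr);
        // Tokens common to all beams are final: copy them before they are shifted off.
        if (const size_t n = beams_state.common_prefix_length) {
            const llama_token * tokens = beams_state.beam_views[0].tokens;
            data.response.insert(data.response.end(), tokens, tokens + n);
        }
    }

    // ... after the prompt has been evaluated (n_past tokens so far) ...
    beam_search_callback_data data{ctx, {}};
    llama_beam_search(ctx, beam_search_callback, &data,
                      /*n_beams=*/ 2, n_past, /*n_predict=*/ 256, n_threads);
    // data.response now holds the tokens of the winning beam.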
@@ -28,6 +28,7 @@ struct gpt_params {
    int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
    float   tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
    int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
    int32_t n_beams = 0; // if non-zero then use beam search of given width.
    float   rope_freq_base  = 10000.0f; // RoPE base frequency
    float   rope_freq_scale = 1.0f; // RoPE frequency scaling factor
@@ -25,6 +25,7 @@ else()
    add_subdirectory(simple)
    add_subdirectory(embd-input)
    add_subdirectory(llama-bench)
    add_subdirectory(beam_search)
    if (LLAMA_METAL)
        add_subdirectory(metal)
    endif()
examples/beam_search/CMakeLists.txt — new file (+8 lines)
@@ -0,0 +1,8 @@
set(TARGET beam_search)
add_executable(${TARGET} beam_search.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
    add_dependencies(${TARGET} BUILD_INFO)
endif()
examples/beam_search/beam_search.cpp — new file (+188 lines)
@@ -0,0 +1,188 @@
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include "common.h"
#include "llama.h"
#include "build-info.h"

#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <signal.h>
#endif

// Used for debugging to print out beam tokens.
struct ostream_beam_view {
    llama_context * ctx;
    llama_beam_view beam_view;
};
std::ostream & operator<<(std::ostream & os, const ostream_beam_view & obv) {
    os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
    for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
        os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
    }
    return os << ')';
}

// Put here anything you want back in beam_search_callback().
struct beam_search_callback_data {
    llama_context * ctx;
    std::vector<llama_token> response;
};

// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos), but this need not always be the same.
// For example, eob can be flagged due to maximum token length, stop words, etc.
bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
    return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
}

// Function matching type llama_beam_search_callback_fn_t.
// This custom callback example is called each time the beams' lengths increase:
//  * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
//    This is also called when the stop condition is met.
//    Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_state) {
    auto & callback_data = *static_cast<beam_search_callback_data *>(callback_data_ptr);
    // Mark beams as EOS as needed.
    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
        llama_beam_view & beam_view = beams_state.beam_views[i];
        if (!beam_view.eob && is_at_eob(callback_data, beam_view.tokens, beam_view.n_tokens)) {
            beam_view.eob = true;
        }
    }
    printf(",");  // Show progress
    if (const size_t n = beams_state.common_prefix_length) {
        callback_data.response.resize(callback_data.response.size() + n);
        assert(0u < beams_state.n_beams);
        const llama_token * tokens = beams_state.beam_views[0].tokens;
        std::copy(tokens, tokens + n, callback_data.response.end() - n);
        printf("%zu", n);
    }
    fflush(stdout);
#if 1 // DEBUG: print current beams for this iteration
    std::cout << "\n\nCurrent beams (last_call=" << beams_state.last_call << "):\n";
    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
        std::cout << "beams[" << i << "]: " << ostream_beam_view{callback_data.ctx, beams_state.beam_views[i]} << std::endl;
    }
#endif
}

int main(int argc, char ** argv)
{
    gpt_params params;
    //params.n_gpu_layers = 200;

    //---------------------------------
    // Print help :
    //---------------------------------

    if ( argc < 2 || argv[1][0] == '-' )
    {
        printf( "Usage: %s MODEL_PATH [BEAM_WIDTH=2] [PROMPT]\n" , argv[0] );
        return 1 ;
    }

    //---------------------------------
    // Load parameters :
    //---------------------------------

    params.model = argv[1];

    params.n_beams = 2 < argc ? std::stoi(argv[2]) : 2;

    if ( argc > 3 )
    {
        params.prompt = argv[3];
    }

    if ( params.prompt.empty() )
    {
        params.prompt = "### Request:\nHow many countries are there?\n\n### Response:\n";
    }

    //---------------------------------
    // Init LLM :
    //---------------------------------

    llama_backend_init(params.numa);

    llama_model * model;
    llama_context * ctx;

    std::tie(model, ctx) = llama_init_from_gpt_params( params );

    if ( model == NULL )
    {
        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
        return 1;
    }

    //---------------------------------
    // Tokenize the prompt :
    //---------------------------------

    std::vector<llama_token> tokens_list = llama_tokenize(ctx, params.prompt, true);

    const size_t max_context_size     = llama_n_ctx( ctx );
    const size_t max_tokens_list_size = max_context_size - 4 ;

    if (tokens_list.size() > max_tokens_list_size)
    {
        fprintf( stderr , "%s: error: prompt too long (%zu tokens, max %zu)\n" ,
                 __func__ , tokens_list.size() , max_tokens_list_size );
        return 1;
    }

    fprintf( stderr, "\n\n" );

    // Print the tokens from the prompt :

    for ( auto id : tokens_list )
    {
        std::cout << llama_token_to_str(ctx, id);
    }
    std::cout << std::flush;

    int n_past = llama_get_kv_cache_token_count(ctx);
    if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
    {
        fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
        return 1;
    }
    n_past += tokens_list.size();

    beam_search_callback_data callback_data{ctx, {}};
    size_t const beam_width = static_cast<size_t>(params.n_beams);
    int const n_predict = 256;
    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);

    std::cout << "\n\n";
    for (llama_token const token_id : callback_data.response) {
        std::cout << llama_token_to_str(ctx, token_id);
    }
    std::cout << std::endl;

    llama_free( ctx );
    llama_free_model( model );

    llama_backend_free();

    return 0;
}
@@ -1209,6 +1209,62 @@ static void log_server_request(const Request &req, const Response &res)
    });
}

bool is_at_eob(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
    return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
}

// Function matching type llama_beam_search_callback_fn_t.
// This custom callback example is called each time the beams' lengths increase:
//  * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
//    This is also called when the stop condition is met.
//    Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
    auto & llama = *static_cast<llama_server_context *>(callback_data);
    // Mark beams as EOS as needed.
    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
        llama_beam_view & beam_view = beams_state.beam_views[i];
        if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
            beam_view.eob = true;
        }
    }
    printf(",");  // Show progress
    if (const size_t n = beams_state.common_prefix_length) {
        llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
        assert(0u < beams_state.n_beams);
        const llama_token * tokens = beams_state.beam_views[0].tokens;
        const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
        std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
        printf("%zu", n);
    }
    fflush(stdout);
#if 0 // DEBUG: print current beams for this iteration
    std::cout << "\n\nCurrent beams:\n";
    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
        std::cout << "beams[" << i << "]: " << ostream_beam_view{llama.ctx, beams_state.beam_views[i]} << std::endl;
    }
#endif
}

struct token_translator {
    llama_context * ctx;
    std::string operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
    std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
};

void append_to_generated_text_from_generated_token_probs(llama_server_context & llama) {
    auto & gtps = llama.generated_token_probs;
    auto translator = token_translator{llama.ctx};
    auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
    const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
    if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
        llama.generated_text.reserve(llama.generated_text.size() + len);
    }
    for (const completion_token_output & cto : gtps) {
        llama.generated_text += translator(cto);
    }
}

int main(int argc, char **argv)
{
    // own arguments required by this example
@@ -1291,6 +1347,13 @@ int main(int argc, char **argv)
            llama.beginCompletion();

            if (!llama.stream) {
                if (llama.params.n_beams) {
                    // Fill llama.generated_token_probs vector with final beam.
                    llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
                                      llama.n_past, llama.n_remain, llama.params.n_threads);
                    // Translate llama.generated_token_probs to llama.generated_text.
                    append_to_generated_text_from_generated_token_probs(llama);
                } else {
                    size_t stop_pos = std::string::npos;

                    while (llama.has_next_token) {
@@ -1308,6 +1371,7 @@ int main(int argc, char **argv)
                        llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
                            llama.generated_text.end());
                    }
                }

                const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
llama.cpp (+251 lines)
@@ -4326,6 +4326,257 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}

//
// Beam search
//

struct llama_beam {
    std::vector<llama_token> tokens;
    float p;  // Cumulative beam probability (renormalized relative to all beams)
    bool eob; // Initialize end-of-beam to false. Callback sets this to true.
    // Sort beams by probability. In case of ties, prefer beams at eob.
    bool operator<(const llama_beam & rhs) const {
        return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob);
    }
    // Shift off first n tokens and discard them.
    void shift_tokens(const size_t n) {
        if (n) {
            std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
            tokens.resize(tokens.size() - n);
        }
    }
    llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; }
};

// A struct for calculating logit-related info.
struct llama_logit_info {
    const float * const logits;
    const int n_vocab;
    const float max_l;
    const float normalizer;
    struct sum_exp {
        float max_l;
        float operator()(float sum, float l) const { return sum + std::exp(l - max_l); }
    };
    llama_logit_info(llama_context * ctx)
      : logits(llama_get_logits(ctx))
      , n_vocab(llama_n_vocab(ctx))
      , max_l(*std::max_element(logits, logits + n_vocab))
      , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l}))
      { }
    llama_token_data get_token_data(const llama_token token_id) const {
        constexpr auto p = std::numeric_limits<float>::quiet_NaN();  // never used
        return {token_id, logits[token_id], p};
    }
    // Return top k token_data by logit.
    std::vector<llama_token_data> top_k(size_t k) {
        std::vector<llama_token_data> min_heap;  // min-heap by logit
        const llama_token k_min = std::min(static_cast<llama_token>(k), n_vocab);
        min_heap.reserve(k_min);
        for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) {
            min_heap.push_back(get_token_data(token_id));
        }
        auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; };
        std::make_heap(min_heap.begin(), min_heap.end(), comp);
        for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) {
            if (min_heap.front().logit < logits[token_id]) {
                std::pop_heap(min_heap.begin(), min_heap.end(), comp);
                min_heap.back().id = token_id;
                min_heap.back().logit = logits[token_id];
                std::push_heap(min_heap.begin(), min_heap.end(), comp);
            }
        }
        return min_heap;
    }
    float probability_from_logit(float logit) {
        return normalizer * std::exp(logit - max_l);
    }
};

struct llama_beam_search_data {
    llama_context * ctx;
    size_t n_beams;
    int n_past;
    int n_predict;
    int n_threads;
    std::vector<llama_beam> beams;
    std::vector<llama_beam> next_beams;

    // Re-calculated on each loop iteration
    size_t common_prefix_length;

    // Used to communicate to/from callback on beams state.
    std::vector<llama_beam_view> beam_views;

    llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
      : ctx(ctx)
      , n_beams(n_beams)
      , n_past(n_past)
      , n_predict(n_predict)
      , n_threads(n_threads)
      , beam_views(n_beams) {
        beams.reserve(n_beams);
        next_beams.reserve(n_beams);
    }

    // Collapse beams to a single beam given by index.
    void collapse_beams(const size_t beam_idx) {
        if (0u < beam_idx) {
            std::swap(beams[0], beams[beam_idx]);
        }
        beams.resize(1);
    }

    // Min-heaps are used to efficiently collect the top-k elements (k=n_beams).
    // The repetitive patterns below reflect the 2 stages of heaps:
    //  * Gather elements until the vector is full, then call std::make_heap() on it.
    //  * If the heap is full and a new element is found that should be included, pop the
    //    least element to the back(), replace it with the new, then push it into the heap.
    void fill_next_beams_by_top_probabilities(llama_beam & beam) {
        // Min-heaps use a greater-than comparator.
        const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; };
        if (beam.eob) {
            // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
            if (next_beams.size() < n_beams) {
                next_beams.push_back(std::move(beam));
                if (next_beams.size() == n_beams) {
                    std::make_heap(next_beams.begin(), next_beams.end(), comp);
                }
            } else if (next_beams.front().p < beam.p) {
                std::pop_heap(next_beams.begin(), next_beams.end(), comp);
                next_beams.back() = std::move(beam);
                std::push_heap(next_beams.begin(), next_beams.end(), comp);
            }
        } else {
            // beam is not at end-of-sentence, so branch with next top_k tokens.
            if (!beam.tokens.empty()) {
                llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
            }
            llama_logit_info logit_info(ctx);
            std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
            size_t i = 0;
            if (next_beams.size() < n_beams) {
                for (; next_beams.size() < n_beams ; ++i) {
                    llama_beam next_beam = beam;
                    next_beam.tokens.push_back(next_tokens[i].id);
                    next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit);
                    next_beams.push_back(std::move(next_beam));
                }
                std::make_heap(next_beams.begin(), next_beams.end(), comp);
            } else {
                for (; next_beams.front().p == 0.0f ; ++i) {
                    std::pop_heap(next_beams.begin(), next_beams.end(), comp);
                    next_beams.back() = beam;
                    next_beams.back().tokens.push_back(next_tokens[i].id);
                    next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit);
                    std::push_heap(next_beams.begin(), next_beams.end(), comp);
                }
            }
            for (; i < n_beams ; ++i) {
                const float next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit);
                if (next_beams.front().p < next_p) {
                    std::pop_heap(next_beams.begin(), next_beams.end(), comp);
                    next_beams.back() = beam;
                    next_beams.back().tokens.push_back(next_tokens[i].id);
                    next_beams.back().p = next_p;
                    std::push_heap(next_beams.begin(), next_beams.end(), comp);
                }
            }
        }
    }

    // Find common_prefix_length based on beams.
    // Requires beams is not empty.
    size_t find_common_prefix_length() {
        size_t common_prefix_length = beams[0].tokens.size();
        for (size_t i = 1 ; i < beams.size() ; ++i) {
            common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
            for (size_t j = 0 ; j < common_prefix_length ; ++j) {
                if (beams[0].tokens[j] != beams[i].tokens[j]) {
                    common_prefix_length = j;
                    break;
                }
            }
        }
        return common_prefix_length;
    }

    // Construct beams_state to send back to caller via the callback function.
    // Side effect: set common_prefix_length = find_common_prefix_length();
    llama_beams_state get_beams_state(const bool last_call) {
        for (size_t i = 0 ; i < beams.size() ; ++i) {
            beam_views[i] = beams[i].view();
        }
        common_prefix_length = find_common_prefix_length();
        return {beam_views.data(), beams.size(), common_prefix_length, last_call};
    }

    // Loop:
    //  * while i < n_predict, AND
    //  * any of the beams have not yet reached end-of-beam (eob), AND
    //  * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
    //    (since all other beam probabilities can only decrease)
    void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) {
        beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eob.
        const auto not_eob = [](const llama_beam & beam) { return !beam.eob; };
        for (int i = 0 ; i < n_predict && std::any_of(beams.begin(), beams.end(), not_eob) &&
                       !beams[top_beam_index()].eob ; ++i) {
            callback(callback_data, get_beams_state(false));  // Sets common_prefix_length
            update_beams_from_beam_views();   // Update values (p,eob) that callback may have changed.
            if (common_prefix_length) {
                llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
                n_past += common_prefix_length;
            }
            // Zero-out next_beam probabilities to place them last in following min-heap.
            std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam & beam) { beam.p = 0.0f; });
            for (llama_beam & beam : beams) {
                beam.shift_tokens(common_prefix_length);
                fill_next_beams_by_top_probabilities(beam);
            }
            // next_beams become the beams of next/final iteration. Swap them to re-use memory.
            beams.swap(next_beams);
            renormalize_beam_probabilities(beams);
        }
        collapse_beams(top_beam_index());
        callback(callback_data, get_beams_state(true));
    }

    // As beams grow, the cumulative probabilities decrease.
    // Renormalize them to avoid floating point underflow.
    static void renormalize_beam_probabilities(std::vector<llama_beam> & beams) {
        const auto sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; };
        const float inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
        std::for_each(beams.begin(), beams.end(), [=](llama_beam & beam) { beam.p *= inv_sum; });
    }

    // Assumes beams is non-empty. Uses llama_beam::operator<() for ordering.
    size_t top_beam_index() {
        return std::max_element(beams.begin(), beams.end()) - beams.begin();
    }

    // Copy (p,eob) for each beam which may have been changed by the callback.
    void update_beams_from_beam_views() {
        for (size_t i = 0 ; i < beams.size() ; ++i) {
            beams[i].p = beam_views[i].p;
            beams[i].eob = beam_views[i].eob;
        }
    }
};

void llama_beam_search(llama_context * ctx,
                       llama_beam_search_callback_fn_t callback, void * callback_data,
                       size_t n_beams, int n_past, int n_predict, int n_threads) {
    assert(ctx);
    const int64_t t_start_sample_us = ggml_time_us();

    llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads);

    beam_search_data.loop(callback, callback_data);

    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
    ctx->n_sample++;
}

//
// quantization
//
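Both top_k() and fill_next_beams_by_top_probabilities() above use the two-stage bounded min-heap pattern their comments describe: fill a vector to capacity and heapify it once, then for each later candidate that beats the current minimum, pop the least element to back(), overwrite it, and push it back in. A standalone sketch of the same pattern over plain floats (illustrative only, not part of the commit):

    // Two-stage bounded min-heap: keep the k largest of a stream in O(n log k).
    #include <algorithm>
    #include <cstddef>
    #include <vector>

    std::vector<float> top_k(const std::vector<float> & values, std::size_t k) {
        // Min-heaps use a greater-than comparator, so the smallest element sits at front().
        const auto comp = [](float a, float b) { return a > b; };
        std::vector<float> heap;
        heap.reserve(k);
        for (float v : values) {
            if (heap.size() < k) {
                heap.push_back(v);                                   // Stage 1: gather until full...
                if (heap.size() == k) {
                    std::make_heap(heap.begin(), heap.end(), comp);  // ...then heapify once.
                }
            } else if (heap.front() < v) {
                std::pop_heap(heap.begin(), heap.end(), comp);       // Stage 2: move least to back(),
                heap.back() = v;                                     // replace it with the new value,
                std::push_heap(heap.begin(), heap.end(), comp);      // and sift it back into the heap.
            }
        }
        return heap;  // the k largest values, in heap (not sorted) order
    }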
llama.h (+37 lines)
@@ -469,6 +469,43 @@ extern "C" {
    /// @details Accepts the sampled token into the grammar
    LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);

    //
    // Beam search
    //

    struct llama_beam_view {
        const llama_token * tokens;
        size_t n_tokens;
        float p;  // Cumulative beam probability (renormalized relative to all beams)
        bool eob; // Callback should set this to true when a beam is at end-of-beam.
    };

    // Passed to beam_search_callback function.
    // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
    // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
    // These pointers are valid only during the synchronous callback, so should not be saved.
    struct llama_beams_state {
        llama_beam_view * beam_views;
        size_t n_beams;              // Number of elements in beam_views[].
        size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
        bool last_call;              // True iff this is the last callback invocation.
    };

    // Type of pointer to the beam_search_callback function.
    // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
    // passed back to beam_search_callback. This avoids having to use global variables in the callback.
    typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state);

    /// @details Deterministically returns entire sentence constructed by a beam search.
    /// @param ctx Pointer to the llama_context.
    /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
    /// @param callback_data A pointer that is simply passed back to callback.
    /// @param n_beams Number of beams to use.
    /// @param n_past Number of tokens already evaluated.
    /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
    /// @param n_threads Number of threads as passed to llama_eval().
    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);

    // Performance information
    LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
    LLAMA_API void llama_print_timings(struct llama_context * ctx);
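Because eob is under callback control, stop conditions other than EOS are straightforward to express: the example callbacks above flag EOS, but a callback could equally cap beam length or match stop words, as the beam_search.cpp comments note. A sketch of a length-capped callback (MAX_BEAM_TOKENS is a hypothetical limit chosen for illustration):

    // Hypothetical: end any beam that reaches MAX_BEAM_TOKENS, in addition to EOS.
    constexpr size_t MAX_BEAM_TOKENS = 64;  // assumed limit, for illustration only

    void capped_callback(void * callback_data, llama_beams_state beams_state) {
        for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
            llama_beam_view & bv = beams_state.beam_views[i];
            if (!bv.eob && bv.n_tokens >= MAX_BEAM_TOKENS) {
                bv.eob = true;  // written back via update_beams_from_beam_views()
            }
        }
        // ... copy beams_state.common_prefix_length tokens out of beam_views[0] here,
        //     as in the examples above; the pointers are only valid during this call.
    }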