llama : extend batch API to select which logits to output

This commit is contained in:
Georgi Gerganov 2023-09-19 00:24:13 +03:00
parent 897caccdf4
commit fa0e677820
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 46 additions and 6 deletions

View File

@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
if (n_eval > n_batch) { if (n_eval > n_batch) {
n_eval = n_batch; n_eval = n_batch;
} }
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, }; llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
if (llama_decode(ctx, batch, params.n_threads)) { if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return false; return false;

View File

@ -82,6 +82,9 @@ int main(int argc, char ** argv) {
const int n_clients = 4; const int n_clients = 4;
// insert new requests as soon as the previous one is done
const bool hot_swap = true;
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("parallel", "log")); log_set_target(log_filename_generator("parallel", "log"));
LOG_TEE("Log start\n"); LOG_TEE("Log start\n");
@ -121,14 +124,23 @@ int main(int argc, char ** argv) {
std::vector<llama_token> batch_token; std::vector<llama_token> batch_token;
std::vector<llama_pos> batch_pos; std::vector<llama_pos> batch_pos;
std::vector<llama_seq_id> batch_seq_id; std::vector<llama_seq_id> batch_seq_id;
std::vector<int8_t> batch_logits;
std::vector<client *> batch_clients; std::vector<client *> batch_clients;
while (true) { int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;
float t_avg = 0.0f;
const int32_t n_seq = 128;
while (g_seq_id < n_seq + n_clients) {
uint32_t n_tokens = 0; uint32_t n_tokens = 0;
batch_token.clear(); batch_token.clear();
batch_pos.clear(); batch_pos.clear();
batch_seq_id.clear(); batch_seq_id.clear();
batch_logits.clear();
for (auto & client : clients) { for (auto & client : clients) {
if (client.seq_id == -1) { if (client.seq_id == -1) {
@ -138,6 +150,7 @@ int main(int argc, char ** argv) {
batch_token.push_back(client.sampled); batch_token.push_back(client.sampled);
batch_pos.push_back(client.n_decoded); batch_pos.push_back(client.n_decoded);
batch_seq_id.push_back(client.seq_id); batch_seq_id.push_back(client.seq_id);
batch_logits.push_back(true);
batch_clients.push_back(&client); batch_clients.push_back(&client);
client.n_decoded += 1; client.n_decoded += 1;
client.i_batch = batch_token.size() - 1; client.i_batch = batch_token.size() - 1;
@ -146,7 +159,9 @@ int main(int argc, char ** argv) {
if (batch_token.empty()) { if (batch_token.empty()) {
// all sequences have ended - clear the entire KV cache // all sequences have ended - clear the entire KV cache
llama_kv_cache_rm_tokens(ctx, -1, -1); llama_kv_cache_rm_tokens(ctx, -1, -1);
}
if (hot_swap || batch_token.empty()) {
for (auto & client : clients) { for (auto & client : clients) {
if (client.seq_id == -1) { if (client.seq_id == -1) {
client.seq_id = g_seq_id; client.seq_id = g_seq_id;
@ -166,7 +181,10 @@ int main(int argc, char ** argv) {
batch_pos.push_back(i); batch_pos.push_back(i);
batch_seq_id.push_back(client.seq_id); batch_seq_id.push_back(client.seq_id);
batch_clients.push_back(&client); batch_clients.push_back(&client);
batch_logits.push_back(false);
} }
batch_logits.back() = true;
client.n_prompt = prompt_tokens.size(); client.n_prompt = prompt_tokens.size();
client.n_decoded = prompt_tokens.size(); client.n_decoded = prompt_tokens.size();
client.i_batch = batch_token.size() - 1; client.i_batch = batch_token.size() - 1;
@ -186,6 +204,7 @@ int main(int argc, char ** argv) {
nullptr, nullptr,
batch_pos.data() + i, batch_pos.data() + i,
batch_seq_id.data() + i, batch_seq_id.data() + i,
batch_logits.data() + i,
0, 0, 0, // unused 0, 0, 0, // unused
}; };
@ -232,14 +251,20 @@ int main(int argc, char ** argv) {
const auto t_main_end = ggml_time_us(); const auto t_main_end = ggml_time_us();
printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n", printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt, client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
(t_main_end - client.t_start_prompt) / 1e6,
(double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6, (double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
(double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6, (double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6,
(double) (client.n_decoded ) / (t_main_end - client.t_start_prompt) * 1e6, (double) (client.n_decoded ) / (t_main_end - client.t_start_prompt) * 1e6,
::trim(client.input).c_str(), ::trim(client.input).c_str(),
::trim(client.response).c_str()); ::trim(client.response).c_str());
n_total_prompt += client.n_prompt;
n_total_gen += client.n_decoded - client.n_prompt;
t_avg += (t_main_end - client.t_start_prompt) / 1e6;
client.seq_id = -1; client.seq_id = -1;
} }
@ -248,6 +273,11 @@ int main(int argc, char ** argv) {
} }
} }
LOG_TEE("\n\n");
LOG_TEE("Total prompt tokens: %d\n", n_total_prompt);
LOG_TEE("Total gen tokens: %d\n", n_total_gen);
LOG_TEE("Avg time per seq: %.2f s\n", t_avg / n_seq);
LOG_TEE("\n\n"); LOG_TEE("\n\n");
llama_print_timings(ctx); llama_print_timings(ctx);

View File

@ -4140,7 +4140,16 @@ static bool llama_eval_internal(
if (lctx.logits_all) { if (lctx.logits_all) {
logits_out.resize(n_vocab * n_tokens); logits_out.resize(n_vocab * n_tokens);
if (batch.logits) {
for (uint32_t i = 0; i < n_tokens; i++) {
if (batch.logits[i] == 0) {
continue;
}
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
}
} else {
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
}
} else { } else {
// return result for just the last token // return result for just the last token
logits_out.resize(n_vocab); logits_out.resize(n_vocab);
@ -7318,7 +7327,7 @@ int llama_eval_embd(
int n_threads) { int n_threads) {
llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1); llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, }; llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
if (!llama_eval_internal(*ctx, batch, n_threads)) { if (!llama_eval_internal(*ctx, batch, n_threads)) {
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@ -7346,6 +7355,7 @@ struct llama_batch llama_batch_get_one(
/*embd =*/ nullptr, /*embd =*/ nullptr,
/*pos =*/ nullptr, /*pos =*/ nullptr,
/*seq_id =*/ nullptr, /*seq_id =*/ nullptr,
/*logits =*/ nullptr,
/*all_pos_0 =*/ pos_0, /*all_pos_0 =*/ pos_0,
/*all_pos_1 =*/ 1, /*all_pos_1 =*/ 1,
/*all_seq_id =*/ seq_id, /*all_seq_id =*/ seq_id,

View File

@ -70,11 +70,11 @@ extern "C" {
typedef struct llama_batch { typedef struct llama_batch {
uint32_t n_tokens; uint32_t n_tokens;
// TODO: not sure about these consts - might just get in the way all the time with no benefit
const llama_token * token; const llama_token * token;
const float * embd; const float * embd;
const llama_pos * pos; const llama_pos * pos;
const llama_seq_id * seq_id; const llama_seq_id * seq_id;
const int8_t * logits; // if 0, do not extract logits for that token
// NOTE: helpers for smooth API transition - can be deprecated in the future // NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below // for future-proof code, use the above fields instead and ignore everything below