llama: Don't double count the sampling time (#2107)

This commit is contained in:
Howard Su 2023-07-05 18:31:23 +08:00 committed by GitHub
parent 9e4475f5cf
commit 051c70dcd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1905,10 +1905,10 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
return; return;
} }
const int64_t t_start_sample_us = ggml_time_us();
llama_sample_softmax(ctx, candidates); llama_sample_softmax(ctx, candidates);
const int64_t t_start_sample_us = ggml_time_us();
// Compute the cumulative probabilities // Compute the cumulative probabilities
float cum_sum = 0.0f; float cum_sum = 0.0f;
size_t last_idx = candidates->size; size_t last_idx = candidates->size;
@@ -1937,9 +1937,8 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array *
return; return;
} }
const int64_t t_start_sample_us = ggml_time_us();
llama_sample_softmax(nullptr, candidates); llama_sample_softmax(nullptr, candidates);
const int64_t t_start_sample_us = ggml_time_us();
// Compute the first and second derivatives // Compute the first and second derivatives
std::vector<float> first_derivatives(candidates->size - 1); std::vector<float> first_derivatives(candidates->size - 1);
@@ -1991,11 +1990,11 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
return; return;
} }
const int64_t t_start_sample_us = ggml_time_us();
// Compute the softmax of logits and calculate entropy // Compute the softmax of logits and calculate entropy
llama_sample_softmax(nullptr, candidates); llama_sample_softmax(nullptr, candidates);
const int64_t t_start_sample_us = ggml_time_us();
float entropy = 0.0f; float entropy = 0.0f;
for (size_t i = 0; i < candidates->size; ++i) { for (size_t i = 0; i < candidates->size; ++i) {
entropy += -candidates->data[i].p * logf(candidates->data[i].p); entropy += -candidates->data[i].p * logf(candidates->data[i].p);
@@ -2164,13 +2163,11 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_
if (ctx) { if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
ctx->n_sample++;
} }
return X; return X;
} }
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
assert(ctx);
int64_t t_start_sample_us; int64_t t_start_sample_us;
t_start_sample_us = ggml_time_us(); t_start_sample_us = ggml_time_us();
@@ -2185,13 +2182,14 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok
candidates->size = 1; candidates->size = 1;
} }
if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
// Normalize the probabilities of the remaining words // Normalize the probabilities of the remaining words
llama_sample_softmax(ctx, candidates); llama_sample_softmax(ctx, candidates);
// Sample the next word X from the remaining words // Sample the next word X from the remaining words
if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
llama_token X = llama_sample_token(ctx, candidates); llama_token X = llama_sample_token(ctx, candidates);
t_start_sample_us = ggml_time_us(); t_start_sample_us = ggml_time_us();