Server: fix seed for multiple slots (#6835)

* Server: add tests for consistent results

* sampling: separate rng per sampling context
This commit is contained in:
Johannes Gäßler 2024-04-24 11:08:36 +02:00 committed by GitHub
parent c0d1b3e03e
commit 28103f4832
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 145 additions and 30 deletions

View File

@ -242,7 +242,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
invalid_param = true;
return true;
}
// This is temporary, in the future the samplign state will be moved fully to llama_sampling_context.
params.seed = std::stoul(argv[i]);
sparams.seed = std::stoul(argv[i]);
return true;
}
if (arg == "-t" || arg == "--threads") {

View File

@ -1,4 +1,6 @@
#define LLAMA_API_INTERNAL
#include "sampling.h"
#include <random>
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
result->prev.resize(params.n_prev);
llama_sampling_set_rng_seed(result, params.seed);
return result;
}
@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
ctx->cur.clear();
}
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
seed = time(NULL);
}
ctx->rng.seed(seed);
}
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
if (dst->grammar) {
llama_grammar_free(dst->grammar);
@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl(
sampler_queue(ctx_main, params, cur_p, min_keep);
id = llama_sample_token(ctx_main, &cur_p);
id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
//{
// const int n_top = 10;

View File

@ -4,9 +4,10 @@
#include "grammar-parser.h"
#include <random>
#include <string>
#include <vector>
#include <unordered_map>
#include <vector>
// sampler types
enum class llama_sampler_type : char {
@ -20,25 +21,26 @@ enum class llama_sampler_type : char {
// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
std::vector<llama_sampler_type> samplers_sequence = {
llama_sampler_type::TOP_K,
@ -79,6 +81,8 @@ struct llama_sampling_context {
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
std::mt19937 rng;
};
#include "common.h"
@ -93,6 +97,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
// - reset grammar
void llama_sampling_reset(llama_sampling_context * ctx);
// Set the sampler seed
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
// Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);

View File

@ -30,7 +30,6 @@ int main(int argc, char ** argv){
// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_set_rng_seed(ctx, params.seed);
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
// tokenize the prompt

View File

@ -38,7 +38,6 @@ int main(int argc, char ** argv){
// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_set_rng_seed(ctx, params.seed);
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
// tokenize the prompt

View File

@ -240,7 +240,6 @@ int main(int argc, char ** argv) {
return 1;
}
session_tokens.resize(n_token_count_out);
llama_set_rng_seed(ctx, params.seed);
LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
}
}

View File

@ -854,7 +854,7 @@ struct server_context {
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.params.seed = json_value(data, "seed", default_params.seed);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
@ -1028,7 +1028,6 @@ struct server_context {
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false;
}
llama_set_rng_seed(ctx, slot.params.seed);
}
slot.command = SLOT_COMMAND_LOAD_PROMPT;

View File

@ -0,0 +1,57 @@
@llama.cpp
@results
Feature: Results
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models
And a model file test-model-00001-of-00003.gguf
And 128 as batch size
And 256 KV cache size
And 128 max tokens to predict
Scenario Outline: Multi users completion
Given <n_slots> slots
And continuous batching
Then the server is starting
Then the server is healthy
Given 42 as seed
And a prompt:
"""
Write a very long story about AI.
"""
Given 42 as seed
And a prompt:
"""
Write a very long story about AI.
"""
Given 42 as seed
And a prompt:
"""
Write a very long story about AI.
"""
Given 42 as seed
And a prompt:
"""
Write a very long story about AI.
"""
Given 42 as seed
And a prompt:
"""
Write a very long story about AI.
"""
Given concurrent completion requests
Then the server is busy
Then the server is idle
And all slots are idle
Then all predictions are equal
Examples:
| n_slots |
| 1 |
| 2 |

View File

@ -61,6 +61,7 @@ def step_server_config(context, server_fqdn, server_port):
context.server_metrics = False
context.server_process = None
context.seed = None
context.draft = None
context.server_seed = None
context.user_api_key = None
context.response_format = None
@ -107,6 +108,11 @@ def step_n_gpu_layer(context, ngl):
context.n_gpu_layer = ngl
@step('{draft:d} as draft')
def step_draft(context, draft):
context.draft = draft
@step('{n_ctx:d} KV cache size')
def step_n_ctx(context, n_ctx):
context.n_ctx = n_ctx
@ -254,6 +260,15 @@ def step_n_tokens_predicted(context, predicted_n):
assert_n_tokens_predicted(context.completion, predicted_n)
@step('all predictions are equal')
@async_run_until_complete
async def step_predictions_equal(context):
n_completions = await gather_tasks_results(context)
assert n_completions >= 2, "need at least 2 completions"
assert_all_predictions_equal(context.tasks_result)
context.tasks_result = []
@step('the completion is truncated')
def step_assert_completion_truncated(context):
step_assert_completion_truncated(context, '')
@ -1020,6 +1035,23 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
f' {n_predicted} <> {expected_predicted_n}')
def assert_all_predictions_equal(completion_responses):
content_0 = completion_responses[0]['content']
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
print(f"content 0: {content_0}")
i = 1
for response in completion_responses[1:]:
content = response['content']
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
print(f"content {i}: {content}")
assert content == content_0, "contents not equal"
i += 1
async def gather_tasks_results(context):
n_tasks = len(context.concurrent_tasks)
@ -1148,6 +1180,8 @@ def start_server_background(context):
server_args.extend(['--ubatch-size', context.n_ubatch])
if context.n_gpu_layer:
server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
if context.draft is not None:
server_args.extend(['--draft', context.draft])
if context.server_continuous_batching:
server_args.append('--cont-batching')
if context.server_embeddings:

View File

@ -13667,7 +13667,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
return result;
}
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
GGML_ASSERT(ctx);
const int64_t t_start_sample_us = ggml_time_us();
@ -13680,7 +13680,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
auto & rng = ctx->rng;
int idx = dist(rng);
llama_token result = candidates->data[idx].id;
@ -13690,6 +13689,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
return result;
}
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
return llama_sample_token_with_rng(ctx, candidates, ctx->rng);
}
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us();

View File

@ -987,7 +987,7 @@ extern "C" {
struct llama_context * ctx,
llama_token_data_array * candidates);
/// @details Randomly selects a token from the candidates based on their probabilities.
/// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
LLAMA_API llama_token llama_sample_token(
struct llama_context * ctx,
llama_token_data_array * candidates);
@ -1074,8 +1074,9 @@ extern "C" {
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
#ifdef LLAMA_API_INTERNAL
#include <vector>
#include <random>
#include <string>
#include <vector>
struct ggml_tensor;
@ -1112,6 +1113,10 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const std::string & src,
llama_partial_utf8 partial_start);
// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
#endif // LLAMA_API_INTERNAL
#endif // LLAMA_H