diff --git a/common/common.cpp b/common/common.cpp index 06f252ea6..a0d1f8d59 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -242,7 +242,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } + // This is temporary, in the future the samplign state will be moved fully to llama_sampling_context. params.seed = std::stoul(argv[i]); + sparams.seed = std::stoul(argv[i]); return true; } if (arg == "-t" || arg == "--threads") { diff --git a/common/sampling.cpp b/common/sampling.cpp index 45d68b26c..f24665501 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,4 +1,6 @@ +#define LLAMA_API_INTERNAL #include "sampling.h" +#include struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { struct llama_sampling_context * result = new llama_sampling_context(); @@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ result->prev.resize(params.n_prev); + llama_sampling_set_rng_seed(result, params.seed); + return result; } @@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) { ctx->cur.clear(); } +void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + seed = time(NULL); + } + ctx->rng.seed(seed); +} + void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { if (dst->grammar) { llama_grammar_free(dst->grammar); @@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl( sampler_queue(ctx_main, params, cur_p, min_keep); - id = llama_sample_token(ctx_main, &cur_p); + id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); //{ // const int n_top = 10; diff --git a/common/sampling.h b/common/sampling.h index 639b819ab..cf7081e36 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -4,9 +4,10 @@ #include "grammar-parser.h" +#include #include -#include #include +#include // sampler types enum class llama_sampler_type : char { @@ -20,25 +21,26 @@ enum class llama_sampler_type : char { // sampling parameters typedef struct llama_sampling_params { - int32_t n_prev = 64; // number of previous tokens to remember - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = false; // consider newlines as a repeatable token + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typical_p = 1.00f; // 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool penalize_nl = false; // consider newlines as a repeatable token + uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context std::vector samplers_sequence = { llama_sampler_type::TOP_K, @@ -79,6 +81,8 @@ struct llama_sampling_context { // TODO: replace with ring-buffer std::vector prev; std::vector cur; + + std::mt19937 rng; }; #include "common.h" @@ -93,6 +97,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx); // - reset grammar void llama_sampling_reset(llama_sampling_context * ctx); +// Set the sampler seed +void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed); + // Copy the sampler context void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index 41b62c2fe..87ecc0a4f 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -30,7 +30,6 @@ int main(int argc, char ** argv){ // load the model std::tie(model, ctx) = llama_init_from_gpt_params(params); - llama_set_rng_seed(ctx, params.seed); GGML_ASSERT(llama_n_vocab(model) < (1 << 16)); // tokenize the prompt diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 9526e898f..eebbd00a5 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -38,7 +38,6 @@ int main(int argc, char ** argv){ // load the model std::tie(model, ctx) = llama_init_from_gpt_params(params); - llama_set_rng_seed(ctx, params.seed); GGML_ASSERT(llama_n_vocab(model) < (1 << 16)); // tokenize the prompt diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 1180734b9..a74d4d9c7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -240,7 +240,6 @@ int main(int argc, char ** argv) { return 1; } session_tokens.resize(n_token_count_out); - llama_set_rng_seed(ctx, params.seed); LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size()); } } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 25bc29639..68c63f9f1 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -854,7 +854,7 @@ struct server_context { slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); - slot.params.seed = json_value(data, "seed", default_params.seed); + slot.sparams.seed = json_value(data, "seed", default_sparams.seed); slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); @@ -1028,7 +1028,6 @@ struct server_context { send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); return false; } - llama_set_rng_seed(ctx, slot.params.seed); } slot.command = SLOT_COMMAND_LOAD_PROMPT; diff --git a/examples/server/tests/features/results.feature b/examples/server/tests/features/results.feature new file mode 100644 index 000000000..f17120f7b --- /dev/null +++ b/examples/server/tests/features/results.feature @@ -0,0 +1,57 @@ +@llama.cpp +@results +Feature: Results + + Background: Server startup + Given a server listening on localhost:8080 + And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models + And a model file test-model-00001-of-00003.gguf + And 128 as batch size + And 256 KV cache size + And 128 max tokens to predict + + Scenario Outline: Multi users completion + Given slots + And continuous batching + Then the server is starting + Then the server is healthy + + Given 42 as seed + And a prompt: + """ + Write a very long story about AI. + """ + + Given 42 as seed + And a prompt: + """ + Write a very long story about AI. + """ + + Given 42 as seed + And a prompt: + """ + Write a very long story about AI. + """ + + Given 42 as seed + And a prompt: + """ + Write a very long story about AI. + """ + + Given 42 as seed + And a prompt: + """ + Write a very long story about AI. + """ + + Given concurrent completion requests + Then the server is busy + Then the server is idle + And all slots are idle + Then all predictions are equal + Examples: + | n_slots | + | 1 | + | 2 | diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index ca400efa4..f71e0d706 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -61,6 +61,7 @@ def step_server_config(context, server_fqdn, server_port): context.server_metrics = False context.server_process = None context.seed = None + context.draft = None context.server_seed = None context.user_api_key = None context.response_format = None @@ -107,6 +108,11 @@ def step_n_gpu_layer(context, ngl): context.n_gpu_layer = ngl +@step('{draft:d} as draft') +def step_draft(context, draft): + context.draft = draft + + @step('{n_ctx:d} KV cache size') def step_n_ctx(context, n_ctx): context.n_ctx = n_ctx @@ -254,6 +260,15 @@ def step_n_tokens_predicted(context, predicted_n): assert_n_tokens_predicted(context.completion, predicted_n) +@step('all predictions are equal') +@async_run_until_complete +async def step_predictions_equal(context): + n_completions = await gather_tasks_results(context) + assert n_completions >= 2, "need at least 2 completions" + assert_all_predictions_equal(context.tasks_result) + context.tasks_result = [] + + @step('the completion is truncated') def step_assert_completion_truncated(context): step_assert_completion_truncated(context, '') @@ -1020,6 +1035,23 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:' f' {n_predicted} <> {expected_predicted_n}') +def assert_all_predictions_equal(completion_responses): + content_0 = completion_responses[0]['content'] + + if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': + print(f"content 0: {content_0}") + + i = 1 + for response in completion_responses[1:]: + content = response['content'] + + if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': + print(f"content {i}: {content}") + + assert content == content_0, "contents not equal" + + i += 1 + async def gather_tasks_results(context): n_tasks = len(context.concurrent_tasks) @@ -1148,6 +1180,8 @@ def start_server_background(context): server_args.extend(['--ubatch-size', context.n_ubatch]) if context.n_gpu_layer: server_args.extend(['--n-gpu-layers', context.n_gpu_layer]) + if context.draft is not None: + server_args.extend(['--draft', context.draft]) if context.server_continuous_batching: server_args.append('--cont-batching') if context.server_embeddings: diff --git a/llama.cpp b/llama.cpp index e4ca34bd1..3a4a03d8f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13667,7 +13667,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da return result; } -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { +llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { GGML_ASSERT(ctx); const int64_t t_start_sample_us = ggml_time_us(); @@ -13680,7 +13680,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra } std::discrete_distribution<> dist(probs.begin(), probs.end()); - auto & rng = ctx->rng; int idx = dist(rng); llama_token result = candidates->data[idx].id; @@ -13690,6 +13689,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra return result; } +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { + return llama_sample_token_with_rng(ctx, candidates, ctx->rng); +} + void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); diff --git a/llama.h b/llama.h index 4effca42c..7bfd13740 100644 --- a/llama.h +++ b/llama.h @@ -987,7 +987,7 @@ extern "C" { struct llama_context * ctx, llama_token_data_array * candidates); - /// @details Randomly selects a token from the candidates based on their probabilities. + /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx. LLAMA_API llama_token llama_sample_token( struct llama_context * ctx, llama_token_data_array * candidates); @@ -1074,8 +1074,9 @@ extern "C" { // Internal API to be implemented by llama.cpp and used by tests/benchmarks only #ifdef LLAMA_API_INTERNAL -#include +#include #include +#include struct ggml_tensor; @@ -1112,6 +1113,10 @@ std::pair, llama_partial_utf8> decode_utf8( const std::string & src, llama_partial_utf8 partial_start); +// Randomly selects a token from the candidates based on their probabilities using given std::mt19937. +// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences. +llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng); + #endif // LLAMA_API_INTERNAL #endif // LLAMA_H