server: tests: add truncated prompt tests, better kv cache size (#5933)

* server: tests: add truncated prompt tests, better size

* server, tests : update regex

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Pierrick Hymbert 2024-03-09 10:30:04 +01:00 committed by GitHub
parent c2101a2e90
commit fd72d2d2a5
4 changed files with 81 additions and 23 deletions

examples/server/server.cpp

@@ -1128,6 +1128,7 @@ struct server_context {
             LOG_VERBOSE("stopped by limit", {
                 {"id_slot",   slot.id},
+                {"id_task",   slot.id_task},
                 {"n_decoded", slot.n_decoded},
                 {"n_predict", slot.params.n_predict},
             });
@@ -1141,6 +1142,8 @@ struct server_context {
         }

         LOG_VERBOSE("next token", {
+            {"id_slot",        slot.id},
+            {"id_task",        slot.id_task},
             {"token",          result.tok},
             {"token_text",     tokens_to_output_formatted_string(ctx, result.tok)},
             {"has_next_token", slot.has_next_token},
@@ -1750,6 +1753,15 @@ struct server_context {
                     slot.n_past = 0;
                     slot.n_prompt_tokens = prompt_tokens.size();

+                    LOG_VERBOSE("prompt tokenized", {
+                        {"id_slot",         slot.id},
+                        {"id_task",         slot.id_task},
+                        {"n_ctx",           slot.n_ctx},
+                        {"n_keep",          slot.params.n_keep},
+                        {"n_prompt_tokens", slot.n_prompt_tokens},
+                        {"prompt_tokens",   tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
+                    });
+
                     if (slot.embedding) {
                         // this prompt is too large to process - discard it
                         if (slot.n_prompt_tokens > n_batch) {
@@ -1788,10 +1800,13 @@ struct server_context {
                         slot.n_prompt_tokens = prompt_tokens.size();

                         LOG_VERBOSE("input truncated", {
-                            {"n_ctx",  slot.n_ctx},
-                            {"n_keep", slot.params.n_keep},
-                            {"n_left", n_left},
-                            {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
+                            {"id_slot",         slot.id},
+                            {"id_task",         slot.id_task},
+                            {"n_ctx",           slot.n_ctx},
+                            {"n_keep",          slot.params.n_keep},
+                            {"n_left",          n_left},
+                            {"n_prompt_tokens", slot.n_prompt_tokens},
+                            {"prompt_tokens",   tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                         });

                         GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
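Note on the truncation path exercised by the new tests: the "input truncated" log fires when a prompt does not fit in the slot's context, and the invariant the code must restore is n_prompt_tokens < n_ctx. A minimal Python sketch of that idea, for illustration only (the server's real algorithm keeps n_keep leading tokens and drops whole blocks from the remainder; the helper name here is hypothetical):

# Hypothetical helper, for illustration; assumes 0 <= n_keep < n_ctx - 1.
def truncate_prompt(prompt_tokens: list, n_ctx: int, n_keep: int) -> list:
    if len(prompt_tokens) < n_ctx:
        return prompt_tokens                                  # already fits
    head = prompt_tokens[:n_keep]                             # always keep the first n_keep tokens
    n_left = n_ctx - n_keep                                   # room remaining after the kept prefix
    tail = prompt_tokens[len(prompt_tokens) - (n_left - 1):]  # keep as much of the tail as still fits
    truncated = head + tail
    assert len(truncated) < n_ctx   # mirrors GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx)
    return truncated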

examples/server/tests/features/parallel.feature

@@ -6,8 +6,8 @@ Feature: Parallel
     Given a server listening on localhost:8080
     And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And   42 as server seed
-    And   512 as batch size
-    And   64 KV cache size
+    And   128 as batch size
+    And   256 KV cache size
     And   2 slots
     And   continuous batching
     Then  the server is starting
@@ -76,6 +76,7 @@ Feature: Parallel
       | disabled     | 128       |
       | enabled      | 64        |
+
   Scenario: Multi users with total number of tokens to predict exceeds the KV Cache size #3969
     Given a prompt:
       """

examples/server/tests/features/server.feature

@@ -10,11 +10,10 @@ Feature: llama.cpp server
     # KV Cache corresponds to the total amount of tokens
     # that can be stored across all independent sequences: #4130
     # see --ctx-size and #5568
-    And   32 KV cache size
-    And   512 as batch size
-    And   1 slots
-    And   embeddings extraction
-    And   32 server max tokens to predict
+    And   256 KV cache size
+    And   32 as batch size
+    And   2 slots
+    And   64 server max tokens to predict
     And   prometheus compatible metrics exposed
     Then  the server is starting
     Then  the server is healthy
@@ -23,18 +22,35 @@ Feature: llama.cpp server
     Then  the server is ready
     And   all slots are idle

   Scenario Outline: Completion
     Given a prompt <prompt>
     And   <n_predict> max tokens to predict
     And   a completion request with no api error
     Then  <n_predicted> tokens are predicted matching <re_content>
+    And   the completion is <truncated> truncated
+    And   <n_prompt> prompt tokens are processed
     And   prometheus metrics are exposed
     And   metric llamacpp:tokens_predicted is <n_predicted>

     Examples: Prompts
-      | prompt                           | n_predict | re_content                       | n_predicted |
-      | I believe the meaning of life is | 8         | (read\|going)+                   | 8           |
-      | Write a joke about AI            | 64        | (park\|friends\|scared\|always)+ | 32          |
+      | prompt                                                                    | n_predict | re_content                    | n_prompt | n_predicted | truncated |
+      | I believe the meaning of life is                                          | 8         | (read\|going)+                | 18       | 8           | not       |
+      | Write a joke about AI from a very long prompt which will not be truncated | 256       | (princesses\|everyone\|kids)+ | 46       | 64          | not       |
+
+  Scenario: Completion prompt truncated
+    Given a prompt:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And   a completion request with no api error
+    Then  64 tokens are predicted matching fun|Annaks|popcorns
+    And   the completion is truncated
+    And   109 prompt tokens are processed

   Scenario Outline: OAI Compatibility
     Given a model <model>
@@ -44,11 +60,14 @@ Feature: llama.cpp server
     And   streaming is <enable_streaming>
     Given an OAI compatible chat completions request with no api error
     Then  <n_predicted> tokens are predicted matching <re_content>
+    And   <n_prompt> prompt tokens are processed
+    And   the completion is <truncated> truncated

     Examples: Prompts
-      | model        | system_prompt               | user_prompt                          | max_tokens | re_content             | n_predicted | enable_streaming |
-      | llama-2      | Book                        | What is the best book                | 8          | (Mom\|what)+           | 8           | disabled         |
-      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64         | (thanks\|happy\|bird)+ | 32          | enabled          |
+      | model        | system_prompt               | user_prompt                          | max_tokens | re_content             | n_prompt | n_predicted | enable_streaming | truncated |
+      | llama-2      | Book                        | What is the best book                | 8          | (Here\|what)+          | 77       | 8           | disabled         | not       |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128        | (thanks\|happy\|bird)+ | -1       | 64          | enabled          |           |

   Scenario: Tokenize / Detokenize
     When tokenizing:
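The sizes in these scenarios line up under the assumption that the KV cache is split evenly across the slots (see --ctx-size and #4130); a back-of-the-envelope check, not the server's exact bookkeeping:

kv_cache_size = 256                        # "And 256 KV cache size"
n_slots       = 2                          # "And 2 slots"
n_ctx_slot    = kv_cache_size // n_slots   # roughly 128 tokens of context per sequence

# 18-token and 46-token prompts leave room for 8 and 64 predicted tokens,
# so those rows expect "not" truncated; the Lorem ipsum prompt tokenizes
# past n_ctx_slot, is cut back to 109 processed tokens, and the response
# is expected to report truncated == true.
assert 18 + 8  <= n_ctx_slot
assert 46 + 64 <= n_ctx_slot
assert 109     <  n_ctx_slot               # post-truncation prompt still fits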

examples/server/tests/features/steps/steps.py

@@ -196,12 +196,30 @@ async def step_request_completion(context, api_error):
 @step(u'{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
-    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)
+    context.completion = context.tasks_result.pop()
+    assert_n_tokens_predicted(context.completion, predicted_n, re_content)


 @step(u'{predicted_n:d} tokens are predicted')
 def step_n_tokens_predicted(context, predicted_n):
-    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)
+    context.completion = context.tasks_result.pop()
+    assert_n_tokens_predicted(context.completion, predicted_n)
+
+
+@step(u'the completion is truncated')
+def step_assert_completion_truncated(context):
+    step_assert_completion_truncated(context, '')
+
+
+@step(u'the completion is {truncated} truncated')
+def step_assert_completion_truncated(context, truncated):
+    truncated = truncated != "not"
+    assert context.completion['truncated'] == truncated, f'{context.completion}'
+
+
+@step(u'{n_prompt:d} prompt tokens are processed')
+def step_impl(context, n_prompt):
+    assert n_prompt < 0 or n_prompt == context.completion['timings']['prompt_n'], f"n_prompt={context.completion['timings']['prompt_n']}"


 @step(u'a user prompt {user_prompt}')
@@ -722,7 +740,8 @@ async def oai_chat_completions(user_prompt,
     completion_response = {
         'content': '',
         'timings': {
-            'predicted_n': 0
+            'predicted_n': 0,
+            'prompt_n': 0
         }
     }
     if async_client:
@@ -763,7 +782,8 @@ async def oai_chat_completions(user_prompt,
         completion_response = {
             'content': chat_completion_raw['choices'][0]['message'],
             'timings': {
-                'predicted_n': chat_completion_raw['usage']['completion_tokens']
+                'predicted_n': chat_completion_raw['usage']['completion_tokens'],
+                'prompt_n': chat_completion_raw['usage']['prompt_tokens']
             }
         }
     else:
@@ -792,13 +812,16 @@ async def oai_chat_completions(user_prompt,
                     if 'content' in delta:
                         completion_response['content'] += delta['content']
                         completion_response['timings']['predicted_n'] += 1
+                completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
     else:
         assert len(chat_completion.choices) == 1
         completion_response = {
             'content': chat_completion.choices[0].message.content,
             'timings': {
-                'predicted_n': chat_completion.usage.completion_tokens
-            }
+                'predicted_n': chat_completion.usage.completion_tokens,
+                'prompt_n': chat_completion.usage.prompt_tokens
+            },
+            'truncated': chat_completion.choices[0].finish_reason != 'stop'
         }
     if debug:
         print("OAI response formatted to llama.cpp:", completion_response)