mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
server: continue to update other slots on embedding concurrent request (#5699)
* server: #5655 - continue to update other slots on embedding concurrent request. * server: tests: add multi users embeddings as fixed * server: tests: adding OAI compatible embedding concurrent endpoint * server: tests: adding OAI compatible embedding with multiple inputs
This commit is contained in:
parent
4c4cb30736
commit
9e359a4f47
@ -1836,7 +1836,7 @@ struct llama_server_context
|
|||||||
send_embedding(slot);
|
send_embedding(slot);
|
||||||
slot.release();
|
slot.release();
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
return true;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
|
@ -1,36 +1,4 @@
|
|||||||
# List of ongoing issues
|
# List of ongoing issues
|
||||||
@bug
|
@bug
|
||||||
Feature: Issues
|
Feature: Issues
|
||||||
# Issue #5655
|
# No confirmed issue at the moment
|
||||||
Scenario: Multi users embeddings
|
|
||||||
Given a server listening on localhost:8080
|
|
||||||
And a model file stories260K.gguf
|
|
||||||
And a model alias tinyllama-2
|
|
||||||
And 42 as server seed
|
|
||||||
And 64 KV cache size
|
|
||||||
And 2 slots
|
|
||||||
And continuous batching
|
|
||||||
And embeddings extraction
|
|
||||||
Then the server is starting
|
|
||||||
Then the server is healthy
|
|
||||||
|
|
||||||
Given a prompt:
|
|
||||||
"""
|
|
||||||
Write a very long story about AI.
|
|
||||||
"""
|
|
||||||
And a prompt:
|
|
||||||
"""
|
|
||||||
Write another very long music lyrics.
|
|
||||||
"""
|
|
||||||
And a prompt:
|
|
||||||
"""
|
|
||||||
Write a very long poem.
|
|
||||||
"""
|
|
||||||
And a prompt:
|
|
||||||
"""
|
|
||||||
Write a very long joke.
|
|
||||||
"""
|
|
||||||
Given concurrent embedding requests
|
|
||||||
Then the server is busy
|
|
||||||
Then the server is idle
|
|
||||||
Then all embeddings are generated
|
|
||||||
|
@ -8,6 +8,7 @@ Feature: Parallel
|
|||||||
And 42 as server seed
|
And 42 as server seed
|
||||||
And 64 KV cache size
|
And 64 KV cache size
|
||||||
And 2 slots
|
And 2 slots
|
||||||
|
And embeddings extraction
|
||||||
And continuous batching
|
And continuous batching
|
||||||
Then the server is starting
|
Then the server is starting
|
||||||
Then the server is healthy
|
Then the server is healthy
|
||||||
@ -75,3 +76,48 @@ Feature: Parallel
|
|||||||
Then the server is busy
|
Then the server is busy
|
||||||
Then the server is idle
|
Then the server is idle
|
||||||
Then all prompts are predicted
|
Then all prompts are predicted
|
||||||
|
|
||||||
|
Scenario: Multi users embeddings
|
||||||
|
Given a prompt:
|
||||||
|
"""
|
||||||
|
Write a very long story about AI.
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
Write another very long music lyrics.
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
Write a very long poem.
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
Write a very long joke.
|
||||||
|
"""
|
||||||
|
Given concurrent embedding requests
|
||||||
|
Then the server is busy
|
||||||
|
Then the server is idle
|
||||||
|
Then all embeddings are generated
|
||||||
|
|
||||||
|
Scenario: Multi users OAI compatibility embeddings
|
||||||
|
Given a prompt:
|
||||||
|
"""
|
||||||
|
In which country Paris is located ?
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
Is Madrid the capital of Spain ?
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
What is the biggest US city ?
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
What is the capital of Bulgaria ?
|
||||||
|
"""
|
||||||
|
And a model tinyllama-2
|
||||||
|
Given concurrent OAI embedding requests
|
||||||
|
Then the server is busy
|
||||||
|
Then the server is idle
|
||||||
|
Then all embeddings are generated
|
||||||
|
@ -60,6 +60,19 @@ Feature: llama.cpp server
|
|||||||
"""
|
"""
|
||||||
Then embeddings are generated
|
Then embeddings are generated
|
||||||
|
|
||||||
|
Scenario: OAI Embeddings compatibility with multiple inputs
|
||||||
|
Given a model tinyllama-2
|
||||||
|
Given a prompt:
|
||||||
|
"""
|
||||||
|
In which country Paris is located ?
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
Is Madrid the capital of Spain ?
|
||||||
|
"""
|
||||||
|
When an OAI compatible embeddings computation request for multiple inputs
|
||||||
|
Then embeddings are generated
|
||||||
|
|
||||||
|
|
||||||
Scenario: Tokenize / Detokenize
|
Scenario: Tokenize / Detokenize
|
||||||
When tokenizing:
|
When tokenizing:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import collections
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@ -261,35 +262,35 @@ def step_a_prompt_prompt(context, prompt):
|
|||||||
@step(u'concurrent completion requests')
|
@step(u'concurrent completion requests')
|
||||||
@async_run_until_complete()
|
@async_run_until_complete()
|
||||||
async def step_concurrent_completion_requests(context):
|
async def step_concurrent_completion_requests(context):
|
||||||
await concurrent_completion_requests(context,
|
await concurrent_requests(context,
|
||||||
request_completion,
|
request_completion,
|
||||||
# prompt is inserted automatically
|
# prompt is inserted automatically
|
||||||
context.base_url,
|
context.base_url,
|
||||||
debug=context.debug,
|
debug=context.debug,
|
||||||
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
|
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
|
||||||
server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
|
server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
|
||||||
user_api_key=context.user_api_key if hasattr(context,
|
user_api_key=context.user_api_key if hasattr(context,
|
||||||
'user_api_key') else None)
|
'user_api_key') else None)
|
||||||
|
|
||||||
|
|
||||||
@step(u'concurrent OAI completions requests')
|
@step(u'concurrent OAI completions requests')
|
||||||
@async_run_until_complete
|
@async_run_until_complete
|
||||||
async def step_oai_chat_completions(context):
|
async def step_oai_chat_completions(context):
|
||||||
await concurrent_completion_requests(context, oai_chat_completions,
|
await concurrent_requests(context, oai_chat_completions,
|
||||||
# user_prompt is inserted automatically
|
# user_prompt is inserted automatically
|
||||||
context.system_prompt,
|
context.system_prompt,
|
||||||
context.base_url,
|
context.base_url,
|
||||||
True, # async_client
|
True, # async_client
|
||||||
model=context.model
|
model=context.model
|
||||||
if hasattr(context, 'model') else None,
|
if hasattr(context, 'model') else None,
|
||||||
n_predict=context.n_predict
|
n_predict=context.n_predict
|
||||||
if hasattr(context, 'n_predict') else None,
|
if hasattr(context, 'n_predict') else None,
|
||||||
enable_streaming=context.enable_streaming
|
enable_streaming=context.enable_streaming
|
||||||
if hasattr(context, 'enable_streaming') else None,
|
if hasattr(context, 'enable_streaming') else None,
|
||||||
server_seed=context.server_seed
|
server_seed=context.server_seed
|
||||||
if hasattr(context, 'server_seed') else None,
|
if hasattr(context, 'server_seed') else None,
|
||||||
user_api_key=context.user_api_key
|
user_api_key=context.user_api_key
|
||||||
if hasattr(context, 'user_api_key') else None)
|
if hasattr(context, 'user_api_key') else None)
|
||||||
|
|
||||||
|
|
||||||
@step(u'all prompts are predicted')
|
@step(u'all prompts are predicted')
|
||||||
@ -316,36 +317,58 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
|
|||||||
@step(u'embeddings are computed for')
|
@step(u'embeddings are computed for')
|
||||||
@async_run_until_complete
|
@async_run_until_complete
|
||||||
async def step_compute_embedding(context):
|
async def step_compute_embedding(context):
|
||||||
content = context.text
|
context.embeddings = await request_embedding(context.text, base_url=context.base_url)
|
||||||
base_url = context.base_url
|
|
||||||
context.embeddings = await request_embedding(content, base_url)
|
|
||||||
|
|
||||||
|
|
||||||
@step(u'embeddings are generated')
|
@step(u'embeddings are generated')
|
||||||
def step_assert_embeddings(context):
|
def step_assert_embeddings(context):
|
||||||
assert_embeddings(context.embeddings)
|
if len(context.prompts) == 0:
|
||||||
|
assert_embeddings(context.embeddings)
|
||||||
|
else:
|
||||||
|
assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
|
||||||
|
f"context.prompts={context.prompts}\n"
|
||||||
|
f"context.embeddings={context.embeddings}")
|
||||||
|
for embedding in context.embeddings:
|
||||||
|
context.prompts.pop()
|
||||||
|
assert_embeddings(embedding)
|
||||||
|
|
||||||
|
|
||||||
@step(u'an OAI compatible embeddings computation request for')
|
@step(u'an OAI compatible embeddings computation request for')
|
||||||
def step_oai_compute_embedding(context):
|
@async_run_until_complete
|
||||||
openai.api_key = 'nope' # openai client always expects an api_keu
|
async def step_oai_compute_embeddings(context):
|
||||||
if context.user_api_key is not None:
|
context.embeddings = await request_oai_embeddings(context.text,
|
||||||
openai.api_key = context.user_api_key
|
base_url=context.base_url,
|
||||||
openai.api_base = f'{context.base_url}/v1'
|
user_api_key=context.user_api_key,
|
||||||
embeddings = openai.Embedding.create(
|
model=context.model)
|
||||||
model=context.model,
|
|
||||||
input=context.text,
|
|
||||||
)
|
@step(u'an OAI compatible embeddings computation request for multiple inputs')
|
||||||
context.embeddings = embeddings
|
@async_run_until_complete
|
||||||
|
async def step_oai_compute_embeddings_multiple_inputs(context):
|
||||||
|
context.embeddings = await request_oai_embeddings(context.prompts,
|
||||||
|
base_url=context.base_url,
|
||||||
|
user_api_key=context.user_api_key,
|
||||||
|
model=context.model)
|
||||||
|
|
||||||
|
|
||||||
@step(u'concurrent embedding requests')
|
@step(u'concurrent embedding requests')
|
||||||
@async_run_until_complete()
|
@async_run_until_complete()
|
||||||
async def step_concurrent_embedding_requests(context):
|
async def step_concurrent_embedding_requests(context):
|
||||||
await concurrent_completion_requests(context,
|
await concurrent_requests(context,
|
||||||
request_embedding,
|
request_embedding,
|
||||||
# prompt is inserted automatically
|
# prompt is inserted automatically
|
||||||
context.base_url)
|
base_url=context.base_url)
|
||||||
|
|
||||||
|
|
||||||
|
@step(u'concurrent OAI embedding requests')
|
||||||
|
@async_run_until_complete()
|
||||||
|
async def step_concurrent_oai_embedding_requests(context):
|
||||||
|
await concurrent_requests(context,
|
||||||
|
request_oai_embeddings,
|
||||||
|
# prompt is inserted automatically
|
||||||
|
base_url=context.base_url,
|
||||||
|
async_client=True,
|
||||||
|
model=context.model)
|
||||||
|
|
||||||
|
|
||||||
@step(u'all embeddings are generated')
|
@step(u'all embeddings are generated')
|
||||||
@ -401,7 +424,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
|
|||||||
assert context.options_response.headers[cors_header] == cors_header_value
|
assert context.options_response.headers[cors_header] == cors_header_value
|
||||||
|
|
||||||
|
|
||||||
async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
|
async def concurrent_requests(context, f_completion, *args, **kwargs):
|
||||||
n_prompts = len(context.prompts)
|
n_prompts = len(context.prompts)
|
||||||
if context.debug:
|
if context.debug:
|
||||||
print(f"starting {n_prompts} concurrent completion requests...")
|
print(f"starting {n_prompts} concurrent completion requests...")
|
||||||
@ -565,7 +588,7 @@ async def oai_chat_completions(user_prompt,
|
|||||||
return completion_response
|
return completion_response
|
||||||
|
|
||||||
|
|
||||||
async def request_embedding(content, base_url):
|
async def request_embedding(content, base_url=None):
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(f'{base_url}/embedding',
|
async with session.post(f'{base_url}/embedding',
|
||||||
json={
|
json={
|
||||||
@ -576,6 +599,46 @@ async def request_embedding(content, base_url):
|
|||||||
return response_json['embedding']
|
return response_json['embedding']
|
||||||
|
|
||||||
|
|
||||||
|
async def request_oai_embeddings(input,
|
||||||
|
base_url=None, user_api_key=None,
|
||||||
|
model=None, async_client=False):
|
||||||
|
# openai client always expects an api_key
|
||||||
|
user_api_key = user_api_key if user_api_key is not None else 'nope'
|
||||||
|
if async_client:
|
||||||
|
origin = 'llama.cpp'
|
||||||
|
if user_api_key is not None:
|
||||||
|
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(f'{base_url}/v1/embeddings',
|
||||||
|
json={
|
||||||
|
"input": input,
|
||||||
|
"model": model,
|
||||||
|
},
|
||||||
|
headers=headers) as response:
|
||||||
|
assert response.status == 200, f"received status code not expected: {response.status}"
|
||||||
|
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||||
|
assert response.headers['Content-Type'] == "application/json; charset=utf-8"
|
||||||
|
response_json = await response.json()
|
||||||
|
assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
|
||||||
|
assert response_json['object'] == 'list'
|
||||||
|
return response_json['data']
|
||||||
|
else:
|
||||||
|
openai.api_key = user_api_key
|
||||||
|
openai.api_base = f'{base_url}/v1'
|
||||||
|
oai_embeddings = openai.Embedding.create(
|
||||||
|
model=model,
|
||||||
|
input=input,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(input, collections.abc.Sequence):
|
||||||
|
embeddings = []
|
||||||
|
for an_oai_embeddings in oai_embeddings.data:
|
||||||
|
embeddings.append(an_oai_embeddings.embedding)
|
||||||
|
else:
|
||||||
|
embeddings = oai_embeddings.data.embedding
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
|
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
|
||||||
content = completion_response['content']
|
content = completion_response['content']
|
||||||
n_predicted = completion_response['timings']['predicted_n']
|
n_predicted = completion_response['timings']['predicted_n']
|
||||||
|
Loading…
Reference in New Issue
Block a user