server: continue to update other slots on embedding concurrent request (#5699)

* server: #5655 - continue to update other slots on embedding concurrent request.

* server: tests: add multi users embeddings as fixed

* server: tests: adding OAI compatible embedding concurrent endpoint

* server: tests: adding OAI compatible embedding with multiple inputs
This commit is contained in:
Pierrick Hymbert 2024-02-24 19:16:04 +01:00 committed by GitHub
parent 4c4cb30736
commit 9e359a4f47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 168 additions and 78 deletions

View File

@ -1836,7 +1836,7 @@ struct llama_server_context
send_embedding(slot); send_embedding(slot);
slot.release(); slot.release();
slot.i_batch = -1; slot.i_batch = -1;
return true; continue;
} }
completion_token_output result; completion_token_output result;

View File

@ -1,36 +1,4 @@
# List of ongoing issues # List of ongoing issues
@bug @bug
Feature: Issues Feature: Issues
# Issue #5655 # No confirmed issue at the moment
Scenario: Multi users embeddings
Given a server listening on localhost:8080
And a model file stories260K.gguf
And a model alias tinyllama-2
And 42 as server seed
And 64 KV cache size
And 2 slots
And continuous batching
And embeddings extraction
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Write a very long story about AI.
"""
And a prompt:
"""
Write another very long music lyrics.
"""
And a prompt:
"""
Write a very long poem.
"""
And a prompt:
"""
Write a very long joke.
"""
Given concurrent embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated

View File

@ -8,6 +8,7 @@ Feature: Parallel
And 42 as server seed And 42 as server seed
And 64 KV cache size And 64 KV cache size
And 2 slots And 2 slots
And embeddings extraction
And continuous batching And continuous batching
Then the server is starting Then the server is starting
Then the server is healthy Then the server is healthy
@ -75,3 +76,48 @@ Feature: Parallel
Then the server is busy Then the server is busy
Then the server is idle Then the server is idle
Then all prompts are predicted Then all prompts are predicted
Scenario: Multi users embeddings
Given a prompt:
"""
Write a very long story about AI.
"""
And a prompt:
"""
Write another very long music lyrics.
"""
And a prompt:
"""
Write a very long poem.
"""
And a prompt:
"""
Write a very long joke.
"""
Given concurrent embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated
Scenario: Multi users OAI compatibility embeddings
Given a prompt:
"""
In which country Paris is located ?
"""
And a prompt:
"""
Is Madrid the capital of Spain ?
"""
And a prompt:
"""
What is the biggest US city ?
"""
And a prompt:
"""
What is the capital of Bulgaria ?
"""
And a model tinyllama-2
Given concurrent OAI embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated

View File

@ -60,6 +60,19 @@ Feature: llama.cpp server
""" """
Then embeddings are generated Then embeddings are generated
Scenario: OAI Embeddings compatibility with multiple inputs
Given a model tinyllama-2
Given a prompt:
"""
In which country Paris is located ?
"""
And a prompt:
"""
Is Madrid the capital of Spain ?
"""
When an OAI compatible embeddings computation request for multiple inputs
Then embeddings are generated
Scenario: Tokenize / Detokenize Scenario: Tokenize / Detokenize
When tokenizing: When tokenizing:

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import collections
import json import json
import os import os
import re import re
@ -261,35 +262,35 @@ def step_a_prompt_prompt(context, prompt):
@step(u'concurrent completion requests') @step(u'concurrent completion requests')
@async_run_until_complete() @async_run_until_complete()
async def step_concurrent_completion_requests(context): async def step_concurrent_completion_requests(context):
await concurrent_completion_requests(context, await concurrent_requests(context,
request_completion, request_completion,
# prompt is inserted automatically # prompt is inserted automatically
context.base_url, context.base_url,
debug=context.debug, debug=context.debug,
n_predict=context.n_predict if hasattr(context, 'n_predict') else None, n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
server_seed=context.server_seed if hasattr(context, 'server_seed') else None, server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
user_api_key=context.user_api_key if hasattr(context, user_api_key=context.user_api_key if hasattr(context,
'user_api_key') else None) 'user_api_key') else None)
@step(u'concurrent OAI completions requests') @step(u'concurrent OAI completions requests')
@async_run_until_complete @async_run_until_complete
async def step_oai_chat_completions(context): async def step_oai_chat_completions(context):
await concurrent_completion_requests(context, oai_chat_completions, await concurrent_requests(context, oai_chat_completions,
# user_prompt is inserted automatically # user_prompt is inserted automatically
context.system_prompt, context.system_prompt,
context.base_url, context.base_url,
True, # async_client True, # async_client
model=context.model model=context.model
if hasattr(context, 'model') else None, if hasattr(context, 'model') else None,
n_predict=context.n_predict n_predict=context.n_predict
if hasattr(context, 'n_predict') else None, if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None, if hasattr(context, 'enable_streaming') else None,
server_seed=context.server_seed server_seed=context.server_seed
if hasattr(context, 'server_seed') else None, if hasattr(context, 'server_seed') else None,
user_api_key=context.user_api_key user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None) if hasattr(context, 'user_api_key') else None)
@step(u'all prompts are predicted') @step(u'all prompts are predicted')
@ -316,36 +317,58 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
@step(u'embeddings are computed for') @step(u'embeddings are computed for')
@async_run_until_complete @async_run_until_complete
async def step_compute_embedding(context): async def step_compute_embedding(context):
content = context.text context.embeddings = await request_embedding(context.text, base_url=context.base_url)
base_url = context.base_url
context.embeddings = await request_embedding(content, base_url)
@step(u'embeddings are generated') @step(u'embeddings are generated')
def step_assert_embeddings(context): def step_assert_embeddings(context):
assert_embeddings(context.embeddings) if len(context.prompts) == 0:
assert_embeddings(context.embeddings)
else:
assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
f"context.prompts={context.prompts}\n"
f"context.embeddings={context.embeddings}")
for embedding in context.embeddings:
context.prompts.pop()
assert_embeddings(embedding)
@step(u'an OAI compatible embeddings computation request for') @step(u'an OAI compatible embeddings computation request for')
def step_oai_compute_embedding(context): @async_run_until_complete
openai.api_key = 'nope' # openai client always expects an api_keu async def step_oai_compute_embeddings(context):
if context.user_api_key is not None: context.embeddings = await request_oai_embeddings(context.text,
openai.api_key = context.user_api_key base_url=context.base_url,
openai.api_base = f'{context.base_url}/v1' user_api_key=context.user_api_key,
embeddings = openai.Embedding.create( model=context.model)
model=context.model,
input=context.text,
) @step(u'an OAI compatible embeddings computation request for multiple inputs')
context.embeddings = embeddings @async_run_until_complete
async def step_oai_compute_embeddings_multiple_inputs(context):
context.embeddings = await request_oai_embeddings(context.prompts,
base_url=context.base_url,
user_api_key=context.user_api_key,
model=context.model)
@step(u'concurrent embedding requests') @step(u'concurrent embedding requests')
@async_run_until_complete() @async_run_until_complete()
async def step_concurrent_embedding_requests(context): async def step_concurrent_embedding_requests(context):
await concurrent_completion_requests(context, await concurrent_requests(context,
request_embedding, request_embedding,
# prompt is inserted automatically # prompt is inserted automatically
context.base_url) base_url=context.base_url)
@step(u'concurrent OAI embedding requests')
@async_run_until_complete()
async def step_concurrent_oai_embedding_requests(context):
await concurrent_requests(context,
request_oai_embeddings,
# prompt is inserted automatically
base_url=context.base_url,
async_client=True,
model=context.model)
@step(u'all embeddings are generated') @step(u'all embeddings are generated')
@ -401,7 +424,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
assert context.options_response.headers[cors_header] == cors_header_value assert context.options_response.headers[cors_header] == cors_header_value
async def concurrent_completion_requests(context, f_completion, *args, **kwargs): async def concurrent_requests(context, f_completion, *args, **kwargs):
n_prompts = len(context.prompts) n_prompts = len(context.prompts)
if context.debug: if context.debug:
print(f"starting {n_prompts} concurrent completion requests...") print(f"starting {n_prompts} concurrent completion requests...")
@ -565,7 +588,7 @@ async def oai_chat_completions(user_prompt,
return completion_response return completion_response
async def request_embedding(content, base_url): async def request_embedding(content, base_url=None):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/embedding', async with session.post(f'{base_url}/embedding',
json={ json={
@ -576,6 +599,46 @@ async def request_embedding(content, base_url):
return response_json['embedding'] return response_json['embedding']
async def request_oai_embeddings(input,
base_url=None, user_api_key=None,
model=None, async_client=False):
# openai client always expects an api_key
user_api_key = user_api_key if user_api_key is not None else 'nope'
if async_client:
origin = 'llama.cpp'
if user_api_key is not None:
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/v1/embeddings',
json={
"input": input,
"model": model,
},
headers=headers) as response:
assert response.status == 200, f"received status code not expected: {response.status}"
assert response.headers['Access-Control-Allow-Origin'] == origin
assert response.headers['Content-Type'] == "application/json; charset=utf-8"
response_json = await response.json()
assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
assert response_json['object'] == 'list'
return response_json['data']
else:
openai.api_key = user_api_key
openai.api_base = f'{base_url}/v1'
oai_embeddings = openai.Embedding.create(
model=model,
input=input,
)
if isinstance(input, collections.abc.Sequence):
embeddings = []
for an_oai_embeddings in oai_embeddings.data:
embeddings.append(an_oai_embeddings.embedding)
else:
embeddings = oai_embeddings.data.embedding
return embeddings
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None): def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
content = completion_response['content'] content = completion_response['content']
n_predicted = completion_response['timings']['predicted_n'] n_predicted = completion_response['timings']['predicted_n']