server-tests : add more type annotations
commit 71b50a148c
parent fbf4a85868
@@ -10,7 +10,7 @@ import time
 from collections.abc import Sequence
 from contextlib import closing
 from re import RegexFlag
-from typing import cast
+from typing import Any, Literal, cast
 
 import aiohttp
 import numpy as np
@@ -23,7 +23,7 @@ from prometheus_client import parser
 # pyright: reportRedeclaration=false
 
 @step("a server listening on {server_fqdn}:{server_port}")
-def step_server_config(context, server_fqdn, server_port):
+def step_server_config(context, server_fqdn: str, server_port: str):
     context.server_fqdn = server_fqdn
     context.server_port = int(server_port)
     context.n_threads = None
@@ -77,34 +77,34 @@ def step_server_config(context, server_fqdn, server_port):
 
 
 @step('a model file {hf_file} from HF repo {hf_repo}')
-def step_download_hf_model(context, hf_file, hf_repo):
+def step_download_hf_model(context, hf_file: str, hf_repo: str):
     context.model_hf_repo = hf_repo
     context.model_hf_file = hf_file
     context.model_file = os.path.basename(hf_file)
 
 
 @step('a model file {model_file}')
-def step_model_file(context, model_file):
+def step_model_file(context, model_file: str):
     context.model_file = model_file
 
 
 @step('a model url {model_url}')
-def step_model_url(context, model_url):
+def step_model_url(context, model_url: str):
     context.model_url = model_url
 
 
 @step('a model alias {model_alias}')
-def step_model_alias(context, model_alias):
+def step_model_alias(context, model_alias: str):
     context.model_alias = model_alias
 
 
 @step('{seed:d} as server seed')
-def step_seed(context, seed):
+def step_seed(context, seed: int):
     context.server_seed = seed
 
 
 @step('{ngl:d} GPU offloaded layers')
-def step_n_gpu_layer(context, ngl):
+def step_n_gpu_layer(context, ngl: int):
     if 'N_GPU_LAYERS' in os.environ:
         new_ngl = int(os.environ['N_GPU_LAYERS'])
         if context.debug:
@@ -114,37 +114,37 @@ def step_n_gpu_layer(context, ngl):
 
 
 @step('{n_threads:d} threads')
-def step_n_threads(context, n_threads):
+def step_n_threads(context, n_threads: int):
     context.n_thread = n_threads
 
 
 @step('{draft:d} as draft')
-def step_draft(context, draft):
+def step_draft(context, draft: int):
     context.draft = draft
 
 
 @step('{n_ctx:d} KV cache size')
-def step_n_ctx(context, n_ctx):
+def step_n_ctx(context, n_ctx: int):
     context.n_ctx = n_ctx
 
 
 @step('{n_slots:d} slots')
-def step_n_slots(context, n_slots):
+def step_n_slots(context, n_slots: int):
     context.n_slots = n_slots
 
 
 @step('{n_predict:d} server max tokens to predict')
-def step_server_n_predict(context, n_predict):
+def step_server_n_predict(context, n_predict: int):
     context.n_server_predict = n_predict
 
 
 @step('{slot_save_path} as slot save path')
-def step_slot_save_path(context, slot_save_path):
+def step_slot_save_path(context, slot_save_path: str):
     context.slot_save_path = slot_save_path
 
 
 @step('using slot id {id_slot:d}')
-def step_id_slot(context, id_slot):
+def step_id_slot(context, id_slot: int):
     context.id_slot = id_slot
 
 
@@ -194,7 +194,7 @@ def step_start_server(context):
 
 @step("the server is {expecting_status}")
 @async_run_until_complete
-async def step_wait_for_the_server_to_be_started(context, expecting_status):
+async def step_wait_for_the_server_to_be_started(context, expecting_status: str):
     match expecting_status:
         case 'healthy':
             await wait_for_health_status(context, context.base_url, 200, 'ok',
@@ -224,7 +224,7 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
 
 @step('all slots are {expected_slot_status_string}')
 @async_run_until_complete
-async def step_all_slots_status(context, expected_slot_status_string):
+async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):
     match expected_slot_status_string:
         case 'idle':
             expected_slot_status = 0
@@ -240,7 +240,7 @@ async def step_all_slots_status(context, expected_slot_status_string):
 
 @step('a completion request with {api_error} api error')
 @async_run_until_complete
-async def step_request_completion(context, api_error):
+async def step_request_completion(context, api_error: Literal['raised'] | str):
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1)
     completion = await request_completion(context.prompts.pop(),
@@ -865,7 +865,7 @@ async def request_completion(prompt,
                              id_slot=None,
                              expect_api_error=None,
                              user_api_key=None,
-                             temperature=None):
+                             temperature=None) -> int | dict[str, Any]:
     if debug:
         print(f"Sending completion request: {prompt}")
     origin = "my.super.domain"
@@ -913,7 +913,7 @@ async def oai_chat_completions(user_prompt,
                                enable_streaming=None,
                                response_format=None,
                                user_api_key=None,
-                               expect_api_error=None):
+                               expect_api_error=None) -> int | dict[str, Any]:
     if debug:
         print(f"Sending OAI Chat completions request: {user_prompt}")
     # openai client always expects an api key
@@ -1035,7 +1035,7 @@ async def oai_chat_completions(user_prompt,
     return completion_response
 
 
-async def request_embedding(content, seed, base_url=None):
+async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{base_url}/embedding',
                                json={
@@ -1048,7 +1048,7 @@ async def request_embedding(content, seed, base_url=None):
 
 async def request_oai_embeddings(input, seed,
                                  base_url=None, user_api_key=None,
-                                 model=None, async_client=False):
+                                 model=None, async_client=False) -> list[list[float]]:
     # openai client always expects an api_key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     if async_client:
@@ -1086,7 +1086,7 @@ async def request_oai_embeddings(input, seed,
             input=input,
         )
 
-        return oai_embeddings.data
+        return [e.embedding for e in oai_embeddings.data]
 
 
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
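
The str/int split in the new annotations mirrors how behave converts step parameters: a placeholder with a ":d" converter (such as "{seed:d}" or "{ngl:d}") is already parsed to an int before the step function runs, while a plain "{param}" placeholder is passed through as a str. Below is a minimal sketch of that conversion using the parse library that behave's default step matcher builds on; the example values ("42", "tinyllama") are made up for illustration and are not taken from the tests.

from parse import parse  # behave's default step matcher builds on this library

# ":d" placeholders are converted to int before the step function is called
result = parse("{seed:d} as server seed", "42 as server seed")
assert result is not None
print(type(result["seed"]), result["seed"])    # <class 'int'> 42

# plain placeholders stay str, hence the ": str" annotations
result = parse("a model alias {model_alias}", "a model alias tinyllama")
assert result is not None
print(type(result["model_alias"]))             # <class 'str'>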