server-tests : add more type annotations

This commit is contained in:
Francis Couture-Harpin 2024-07-06 19:27:38 -04:00
parent fbf4a85868
commit 71b50a148c

View File

@ -10,7 +10,7 @@ import time
from collections.abc import Sequence
from contextlib import closing
from re import RegexFlag
from typing import cast
from typing import Any, Literal, cast
import aiohttp
import numpy as np
@ -23,7 +23,7 @@ from prometheus_client import parser
# pyright: reportRedeclaration=false
@step("a server listening on {server_fqdn}:{server_port}")
def step_server_config(context, server_fqdn, server_port):
def step_server_config(context, server_fqdn: str, server_port: str):
context.server_fqdn = server_fqdn
context.server_port = int(server_port)
context.n_threads = None
@ -77,34 +77,34 @@ def step_server_config(context, server_fqdn, server_port):
@step('a model file {hf_file} from HF repo {hf_repo}')
def step_download_hf_model(context, hf_file, hf_repo):
def step_download_hf_model(context, hf_file: str, hf_repo: str):
context.model_hf_repo = hf_repo
context.model_hf_file = hf_file
context.model_file = os.path.basename(hf_file)
@step('a model file {model_file}')
def step_model_file(context, model_file):
def step_model_file(context, model_file: str):
context.model_file = model_file
@step('a model url {model_url}')
def step_model_url(context, model_url):
def step_model_url(context, model_url: str):
context.model_url = model_url
@step('a model alias {model_alias}')
def step_model_alias(context, model_alias):
def step_model_alias(context, model_alias: str):
context.model_alias = model_alias
@step('{seed:d} as server seed')
def step_seed(context, seed):
def step_seed(context, seed: int):
context.server_seed = seed
@step('{ngl:d} GPU offloaded layers')
def step_n_gpu_layer(context, ngl):
def step_n_gpu_layer(context, ngl: int):
if 'N_GPU_LAYERS' in os.environ:
new_ngl = int(os.environ['N_GPU_LAYERS'])
if context.debug:
@ -114,37 +114,37 @@ def step_n_gpu_layer(context, ngl):
@step('{n_threads:d} threads')
def step_n_threads(context, n_threads):
def step_n_threads(context, n_threads: int):
context.n_thread = n_threads
@step('{draft:d} as draft')
def step_draft(context, draft):
def step_draft(context, draft: int):
context.draft = draft
@step('{n_ctx:d} KV cache size')
def step_n_ctx(context, n_ctx):
def step_n_ctx(context, n_ctx: int):
context.n_ctx = n_ctx
@step('{n_slots:d} slots')
def step_n_slots(context, n_slots):
def step_n_slots(context, n_slots: int):
context.n_slots = n_slots
@step('{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict):
def step_server_n_predict(context, n_predict: int):
context.n_server_predict = n_predict
@step('{slot_save_path} as slot save path')
def step_slot_save_path(context, slot_save_path):
def step_slot_save_path(context, slot_save_path: str):
context.slot_save_path = slot_save_path
@step('using slot id {id_slot:d}')
def step_id_slot(context, id_slot):
def step_id_slot(context, id_slot: int):
context.id_slot = id_slot
@ -194,7 +194,7 @@ def step_start_server(context):
@step("the server is {expecting_status}")
@async_run_until_complete
async def step_wait_for_the_server_to_be_started(context, expecting_status):
async def step_wait_for_the_server_to_be_started(context, expecting_status: str):
match expecting_status:
case 'healthy':
await wait_for_health_status(context, context.base_url, 200, 'ok',
@ -224,7 +224,7 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
@step('all slots are {expected_slot_status_string}')
@async_run_until_complete
async def step_all_slots_status(context, expected_slot_status_string):
async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):
match expected_slot_status_string:
case 'idle':
expected_slot_status = 0
@ -240,7 +240,7 @@ async def step_all_slots_status(context, expected_slot_status_string):
@step('a completion request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error):
async def step_request_completion(context, api_error: Literal['raised'] | str):
expect_api_error = api_error == 'raised'
seeds = await completions_seed(context, num_seeds=1)
completion = await request_completion(context.prompts.pop(),
@ -865,7 +865,7 @@ async def request_completion(prompt,
id_slot=None,
expect_api_error=None,
user_api_key=None,
temperature=None):
temperature=None) -> int | dict[str, Any]:
if debug:
print(f"Sending completion request: {prompt}")
origin = "my.super.domain"
@ -913,7 +913,7 @@ async def oai_chat_completions(user_prompt,
enable_streaming=None,
response_format=None,
user_api_key=None,
expect_api_error=None):
expect_api_error=None) -> int | dict[str, Any]:
if debug:
print(f"Sending OAI Chat completions request: {user_prompt}")
# openai client always expects an api key
@ -1035,7 +1035,7 @@ async def oai_chat_completions(user_prompt,
return completion_response
async def request_embedding(content, seed, base_url=None):
async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/embedding',
json={
@ -1048,7 +1048,7 @@ async def request_embedding(content, seed, base_url=None):
async def request_oai_embeddings(input, seed,
base_url=None, user_api_key=None,
model=None, async_client=False):
model=None, async_client=False) -> list[list[float]]:
# openai client always expects an api_key
user_api_key = user_api_key if user_api_key is not None else 'nope'
if async_client:
@ -1086,7 +1086,7 @@ async def request_oai_embeddings(input, seed,
input=input,
)
return oai_embeddings.data
return [e.embedding for e in oai_embeddings.data]
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):