server-tests : add more type annotations

Francis Couture-Harpin 2024-07-06 19:27:38 -04:00
parent fbf4a85868
commit 71b50a148c

@@ -10,7 +10,7 @@ import time
 from collections.abc import Sequence
 from contextlib import closing
 from re import RegexFlag
-from typing import cast
+from typing import Any, Literal, cast

 import aiohttp
 import numpy as np
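Note: the widened import is what the annotations below lean on. `Literal` constrains step arguments, and `Any` shows up in the `dict[str, Any]` return types for parsed JSON bodies. A minimal sketch of how the three imports combine (names here are illustrative, not from the test file):

    from typing import Any, Literal, cast

    JsonBody = dict[str, Any]      # shape of a parsed JSON response
    ApiError = Literal['raised']   # the one value the steps compare against

    # cast() narrows for the type checker only; it does nothing at runtime.
    body = cast(JsonBody, {"content": "hello"})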
@@ -23,7 +23,7 @@ from prometheus_client import parser
 # pyright: reportRedeclaration=false

 @step("a server listening on {server_fqdn}:{server_port}")
-def step_server_config(context, server_fqdn, server_port):
+def step_server_config(context, server_fqdn: str, server_port: str):
     context.server_fqdn = server_fqdn
     context.server_port = int(server_port)
     context.n_threads = None
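Note: behave passes untyped `{placeholder}` captures to step implementations as plain strings, so `server_port: str` is accurate even though the value is numeric, and the explicit `int(server_port)` conversion stays necessary. The same pattern outside behave, as a hedged sketch:

    # Illustrative only: an untyped capture arrives as str and must be converted.
    def configure(server_fqdn: str, server_port: str) -> tuple[str, int]:
        return server_fqdn, int(server_port)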
@@ -77,34 +77,34 @@ def step_server_config(context, server_fqdn, server_port):

 @step('a model file {hf_file} from HF repo {hf_repo}')
-def step_download_hf_model(context, hf_file, hf_repo):
+def step_download_hf_model(context, hf_file: str, hf_repo: str):
     context.model_hf_repo = hf_repo
     context.model_hf_file = hf_file
     context.model_file = os.path.basename(hf_file)


 @step('a model file {model_file}')
-def step_model_file(context, model_file):
+def step_model_file(context, model_file: str):
     context.model_file = model_file


 @step('a model url {model_url}')
-def step_model_url(context, model_url):
+def step_model_url(context, model_url: str):
     context.model_url = model_url


 @step('a model alias {model_alias}')
-def step_model_alias(context, model_alias):
+def step_model_alias(context, model_alias: str):
     context.model_alias = model_alias


 @step('{seed:d} as server seed')
-def step_seed(context, seed):
+def step_seed(context, seed: int):
     context.server_seed = seed


 @step('{ngl:d} GPU offloaded layers')
-def step_n_gpu_layer(context, ngl):
+def step_n_gpu_layer(context, ngl: int):
     if 'N_GPU_LAYERS' in os.environ:
         new_ngl = int(os.environ['N_GPU_LAYERS'])
         if context.debug:
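Note: steps such as '{seed:d}' use the ':d' converter from the parse library (behave's matcher), which hands the function an int already; the `seed: int` and `ngl: int` annotations just surface that for pyright. A quick standalone check of the converter, assuming the parse package is installed:

    import parse

    result = parse.parse('{seed:d} as server seed', '42 as server seed')
    assert result is not None
    assert result['seed'] == 42   # an int, not the string '42'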
@@ -114,37 +114,37 @@ def step_n_gpu_layer(context, ngl):

 @step('{n_threads:d} threads')
-def step_n_threads(context, n_threads):
+def step_n_threads(context, n_threads: int):
     context.n_thread = n_threads


 @step('{draft:d} as draft')
-def step_draft(context, draft):
+def step_draft(context, draft: int):
     context.draft = draft


 @step('{n_ctx:d} KV cache size')
-def step_n_ctx(context, n_ctx):
+def step_n_ctx(context, n_ctx: int):
     context.n_ctx = n_ctx


 @step('{n_slots:d} slots')
-def step_n_slots(context, n_slots):
+def step_n_slots(context, n_slots: int):
     context.n_slots = n_slots


 @step('{n_predict:d} server max tokens to predict')
-def step_server_n_predict(context, n_predict):
+def step_server_n_predict(context, n_predict: int):
     context.n_server_predict = n_predict


 @step('{slot_save_path} as slot save path')
-def step_slot_save_path(context, slot_save_path):
+def step_slot_save_path(context, slot_save_path: str):
     context.slot_save_path = slot_save_path


 @step('using slot id {id_slot:d}')
-def step_id_slot(context, id_slot):
+def step_id_slot(context, id_slot: int):
     context.id_slot = id_slot
@@ -194,7 +194,7 @@ def step_start_server(context):
 @step("the server is {expecting_status}")
 @async_run_until_complete
-async def step_wait_for_the_server_to_be_started(context, expecting_status):
+async def step_wait_for_the_server_to_be_started(context, expecting_status: str):
     match expecting_status:
         case 'healthy':
             await wait_for_health_status(context, context.base_url, 200, 'ok',
@step('all slots are {expected_slot_status_string}') @step('all slots are {expected_slot_status_string}')
@async_run_until_complete @async_run_until_complete
async def step_all_slots_status(context, expected_slot_status_string): async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):
match expected_slot_status_string: match expected_slot_status_string:
case 'idle': case 'idle':
expected_slot_status = 0 expected_slot_status = 0
@ -240,7 +240,7 @@ async def step_all_slots_status(context, expected_slot_status_string):
@step('a completion request with {api_error} api error') @step('a completion request with {api_error} api error')
@async_run_until_complete @async_run_until_complete
async def step_request_completion(context, api_error): async def step_request_completion(context, api_error: Literal['raised'] | str):
expect_api_error = api_error == 'raised' expect_api_error = api_error == 'raised'
seeds = await completions_seed(context, num_seeds=1) seeds = await completions_seed(context, num_seeds=1)
completion = await request_completion(context.prompts.pop(), completion = await request_completion(context.prompts.pop(),
@@ -865,7 +865,7 @@ async def request_completion(prompt,
                              id_slot=None,
                              expect_api_error=None,
                              user_api_key=None,
-                             temperature=None):
+                             temperature=None) -> int | dict[str, Any]:
     if debug:
         print(f"Sending completion request: {prompt}")
     origin = "my.super.domain"
@@ -913,7 +913,7 @@ async def oai_chat_completions(user_prompt,
                               enable_streaming=None,
                               response_format=None,
                               user_api_key=None,
-                              expect_api_error=None):
+                              expect_api_error=None) -> int | dict[str, Any]:
     if debug:
         print(f"Sending OAI Chat completions request: {user_prompt}")
     # openai client always expects an api key
@@ -1035,7 +1035,7 @@ async def oai_chat_completions(user_prompt,
     return completion_response


-async def request_embedding(content, seed, base_url=None):
+async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
     async with aiohttp.ClientSession() as session:
         async with session.post(f'{base_url}/embedding',
                                 json={
@@ -1048,7 +1048,7 @@ async def request_embedding(content, seed, base_url=None):
 async def request_oai_embeddings(input, seed,
                                  base_url=None, user_api_key=None,
-                                 model=None, async_client=False):
+                                 model=None, async_client=False) -> list[list[float]]:
     # openai client always expects an api_key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     if async_client:
@@ -1086,7 +1086,7 @@ async def request_oai_embeddings(input, seed,
             input=input,
         )
-        return oai_embeddings.data
+        return [e.embedding for e in oai_embeddings.data]


 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
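Note: this last hunk is the one place the commit changes behavior rather than only adding annotations: the OpenAI client wraps each vector in an embedding object, so returning `.data` as-is would not satisfy the declared `list[list[float]]`. The comprehension unwraps the raw vectors. Sketched here against the openai>=1.0 client with a hypothetical local endpoint (both assumptions, not shown in the diff):

    from openai import OpenAI

    client = OpenAI(base_url='http://localhost:8080/v1', api_key='nope')
    res = client.embeddings.create(model='ignored', input=['hello', 'world'])
    vectors: list[list[float]] = [e.embedding for e in res.data]  # plain float lists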