import pytest
from openai import OpenAI
from utils import *
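
# Tests for the server's OpenAI-compatible /chat/completions endpoint.
# ServerPreset and match_regex come from the shared test utils imported above.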
server = ServerPreset.tinyllama2()
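

# autouse fixture: give every test in this module a fresh tinyllama2 server preset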
@pytest.fixture(scope="module", autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
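

# non-streaming chat completion: verify token counts, response id/fingerprint format,
# reported model name, generated content, and finish_reason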
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    })
    assert res.status_code == 200
    assert "cmpl" in res.body["id"]  # make sure the completion id has the expected format
    assert res.body["system_fingerprint"].startswith("b")
    # the server reports the requested model, or its alias when no model was specified;
    # parenthesized so the comparison applies in both branches of the conditional
    assert res.body["model"] == (model if model is not None else server.model_alias)
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
    choice = res.body["choices"][0]
    assert "assistant" == choice["message"]["role"]
    assert match_regex(re_content, choice["message"]["content"])
    assert choice["finish_reason"] == finish_reason
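

# streaming chat completion: the completion id must be identical across all SSE events,
# and usage stats plus finish_reason arrive only with the final chunk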
@pytest.mark.parametrize(
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
    ]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
    global server
    server.model_alias = None  # try using DEFAULT_OAICOMPAT_MODEL
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })
    content = ""
    last_cmpl_id = None
    for data in res:
        choice = data["choices"][0]
        assert data["system_fingerprint"].startswith("b")
        assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, may change in the future
        if last_cmpl_id is None:
            last_cmpl_id = data["id"]
        assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
        if choice["finish_reason"] in ["stop", "length"]:
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
            assert "content" not in choice["delta"]
            assert match_regex(re_content, content)
            assert choice["finish_reason"] == finish_reason
        else:
            assert choice["finish_reason"] is None
            content += choice["delta"]["content"]
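

# the endpoint should also work when accessed through the official openai Python client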
def test_chat_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].message.content is not None
    assert match_regex("(Suddenly)+", res.choices[0].message.content)
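

# the llama3 chat template must render the messages into the expected prompt string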
def test_chat_template():
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ]
    })
    assert res.status_code == 200
    assert "__verbose" in res.body
    assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
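

# response_format: valid JSON schemas constrain the output; malformed ones must be rejected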
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
    ({"type": "json_object"}, 10, "(\\{|John)+"),
    # invalid response formats (expected to fail)
    ({"type": "sound"}, 0, None),
    ({"type": "json_object", "schema": 123}, 0, None),
    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    })
    if re_content is not None:
        assert res.status_code == 200
        choice = res.body["choices"][0]
        assert match_regex(re_content, choice["message"]["content"])
    else:
        assert res.status_code != 200
        assert "error" in res.body
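

# malformed "messages" payloads must be rejected with an error status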
@pytest.mark.parametrize("messages", [
    None,
    "string",
    [123],
    [{}],
    [{"role": 123}],
    [{"role": "system", "content": 123}],
    # [{"content": "hello"}], # TODO: should not be a valid case
    [{"role": "system", "content": "test"}, {}],
])
def test_invalid_chat_completion_req(messages):
    global server
    server.start()
    res = server.make_request("POST", "/chat/completions", data={
        "messages": messages,
    })
    assert res.status_code == 400 or res.status_code == 500
    assert "error" in res.body
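

# with timings_per_token enabled, every streamed chunk carries timing statistics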
def test_chat_completion_with_timings_per_token():
    global server
    server.start()
    res = server.make_stream_request("POST", "/chat/completions", data={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "timings_per_token": True,
    })
    for data in res:
        assert "timings" in data
        assert "prompt_per_second" in data["timings"]
        assert "predicted_per_second" in data["timings"]
        assert "predicted_n" in data["timings"]
        assert data["timings"]["predicted_n"] <= 10
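

# non-streaming logprobs: every token carries a logprob <= 0, raw bytes, and top
# alternatives, and concatenating the tokens reproduces the message content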
def test_logprobs():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
    )
    output_text = res.choices[0].message.content
    aggregated_text = ''
    assert res.choices[0].logprobs is not None
    assert res.choices[0].logprobs.content is not None
    for token in res.choices[0].logprobs.content:
        aggregated_text += token.token
        assert token.logprob <= 0.0
        assert token.bytes is not None
        assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text
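

# streaming logprobs: same invariants, checked chunk by chunk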
def test_logprobs_stream():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
        max_tokens=5,
        logprobs=True,
        top_logprobs=10,
        stream=True,
    )
    output_text = ''
    aggregated_text = ''
    for data in res:
        choice = data.choices[0]
        if choice.finish_reason is None:
            if choice.delta.content:
                output_text += choice.delta.content
            assert choice.logprobs is not None
            assert choice.logprobs.content is not None
            for token in choice.logprobs.content:
                aggregated_text += token.token
                assert token.logprob <= 0.0
                assert token.bytes is not None
                assert token.top_logprobs is not None
                assert len(token.top_logprobs) > 0
    assert aggregated_text == output_text