2023-07-11 23:50:08 +02:00
import time
2023-09-16 05:11:16 +02:00
2023-07-11 23:50:08 +02:00
import tiktoken
import torch
import torch . nn . functional as F
2023-09-16 05:11:16 +02:00
import yaml
from extensions . openai . defaults import clamp , default , get_default_req_params
from extensions . openai . errors import InvalidRequestError
from extensions . openai . utils import debug_msg , end_line
2023-07-11 23:50:08 +02:00
from modules import shared
2023-09-16 05:11:16 +02:00
from modules . text_generation import decode , encode , generate_reply
from transformers import LogitsProcessor , LogitsProcessorList
2023-07-11 23:50:08 +02:00
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
class LogitsBiasProcessor(LogitsProcessor):
    """Add a fixed per-token bias to the logits at every generation step.

    ``logit_bias`` maps token ids (``str`` or ``int`` keys, e.g. the JSON
    object of an OpenAI request) to the bias added to that token's logit.
    An empty or missing mapping makes the processor a no-op.
    """

    def __init__(self, logit_bias=None):
        # None instead of a shared mutable ``{}`` default.
        self.logit_bias = logit_bias if logit_bias is not None else {}
        if self.logit_bias:
            # Read keys and values in the same dict order instead of looking
            # values up again via str(int(key)), which fails for keys such as
            # "01" or for int keys.
            self.keys = [int(key) for key in self.logit_bias.keys()]
            values = [float(value) for value in self.logit_bias.values()]
            # Some model wrappers may not expose ``.device``; the tensor is
            # moved to the logits' device on first use anyway.
            self.values = torch.tensor(values, dtype=torch.float, device=getattr(shared.model, 'device', None))
            debug_msg(f"{self}")

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        """Add the biases to ``logits`` in place (every batch row) and return it."""
        if self.logit_bias:
            # Keep the bias tensor on the same device/dtype as the logits so the
            # in-place add works when they differ from the model's device.
            if self.values.device != logits.device or self.values.dtype != logits.dtype:
                self.values = self.values.to(device=logits.device, dtype=logits.dtype)
            debug_msg(logits[0, self.keys], " + ", self.values)
            logits[:, self.keys] += self.values  # broadcasts over the batch dimension
            debug_msg(" --> ", logits[0, self.keys])
            debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
2023-07-11 23:50:08 +02:00
2023-09-16 05:11:16 +02:00
2023-07-11 23:50:08 +02:00
class LogprobProcessor(LogitsProcessor):
    """Record the most likely next tokens and their log-probabilities.

    The logits are never modified. After each call, ``token_alternatives``
    maps decoded token text to its natural-log probability for the
    ``logprobs + 1`` most likely tokens of batch row 0 at the latest step.
    ``logprobs=None`` disables recording.
    """

    def __init__(self, logprobs=None):
        self.logprobs = logprobs
        self.token_alternatives = {}

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        if self.logprobs is not None:  # 0-5
            log_e_probabilities = F.log_softmax(logits, dim=1)
            # Never ask topk for more entries than the vocabulary has.
            k = min(self.logprobs + 1, log_e_probabilities.shape[-1])
            top_values, top_indices = torch.topk(log_e_probabilities, k=k)

            alternatives = {}
            for token_id, logprob in zip(top_indices[0], top_values[0]):
                # Several ids can decode to the same text. topk returns values
                # sorted descending, so keep the first (most probable) one
                # instead of letting a later, less likely id overwrite it.
                alternatives.setdefault(decode(token_id), float(logprob))
            # Build a fresh dict each step so references handed out earlier
            # are not mutated.
            self.token_alternatives = alternatives
            debug_msg(repr(self))
        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"
2023-07-11 23:50:08 +02:00
def convert_logprobs_to_tiktoken(model, logprobs):
    """Return the ``{token_text: logprob}`` mapping to report to the client.

    Re-keying the native tokens onto ``model``'s tiktoken vocabulary was tried
    and judged more trouble than it is worth: a native token can encode to
    several tiktoken tokens, so the mapping is lossy. The native mapping is
    therefore returned unchanged (the same object). ``model`` is accepted only
    so callers keep a stable interface.
    """
    return logprobs
2023-07-11 23:50:08 +02:00
def _map_logit_bias_to_native(logit_bias, requested_model):
    """Translate OpenAI ``logit_bias`` token ids into the loaded model's ids.

    Keys are treated as ids in the tiktoken encoding of ``requested_model``.
    Each id is decoded to text and re-encoded with the native tokenizer, and
    every resulting native token gets the same bias. If tiktoken has no
    encoding for the model (``KeyError``), the keys are assumed to already be
    native ids and ``logit_bias`` is returned unchanged.

    Raises:
        InvalidRequestError: a key is not an integer token id.
    """
    # Validate up front so a bad key becomes a client error instead of a
    # ValueError from int() further down (or inside the logits processor).
    for key in logit_bias:
        try:
            int(key)
        except (TypeError, ValueError):
            raise InvalidRequestError(message=f"logit_bias: invalid token id {key!r}", param='logit_bias') from None

    # Native ids that are never biased (original note: "XXX LLAMA tokens").
    skipped_native_ids = (0, 1, 2, 29871)

    # XXX convert tokens from tiktoken based on requested model
    # Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
    try:
        encoder = tiktoken.encoding_for_model(requested_model)
        new_logit_bias = {}
        for logit, bias in logit_bias.items():
            for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
                if int(x) in skipped_native_ids:
                    continue
                new_logit_bias[str(int(x))] = bias
        debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
        return new_logit_bias
    except KeyError:
        return logit_bias  # assume native tokens if we can't find the tokenizer


def marshal_common_params(body):
    """Translate an OpenAI request body into text-generation parameters.

    Starts from ``get_default_req_params()``, overlays server settings and the
    request's OpenAI fields, and attaches logits processors for ``logit_bias``
    and ``logprobs`` when requested.

    The returned dict also carries private entries the endpoint functions pop
    before generation: ``requested_model`` (the client's ``model``) and, when
    ``logprobs`` was sent, ``logprob_proc``.

    Raises:
        InvalidRequestError: ``n`` other than 1, or a non-integer
        ``logit_bias`` key.
    """
    # Request Parameters
    # Try to use openai defaults or map them to something with the same intent
    req_params = get_default_req_params()

    # Common request parameters
    req_params['truncation_length'] = shared.settings['truncation_length']
    req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
    req_params['seed'] = shared.settings.get('seed', req_params['seed'])
    req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']

    # OpenAI API Parameters
    # model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
    req_params['requested_model'] = body.get('model', shared.model_name)

    req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
    req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99)  # fixup absolute 0.0/2.0
    req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0)

    n = default(body, 'n', 1)
    if n != 1:
        raise InvalidRequestError(message="Only n = 1 is supported.", param='n')

    # stop: str or array, max len 4 (the length limit is not enforced)
    stop = body.get('stop')
    if isinstance(stop, str):
        req_params['stopping_strings'] = [stop]  # non-standard parameter
    elif isinstance(stop, list):
        # Copy: stopping_strings is extended in place later (messages_to_prompt),
        # which would otherwise mutate the client's request body.
        req_params['stopping_strings'] = list(stop)

    # presence_penalty - ignored
    # frequency_penalty - ignored

    # pass through unofficial params
    req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty'])
    req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty'])

    # user - ignored

    logits_processor = []
    logit_bias = body.get('logit_bias', None)
    if logit_bias:  # {str: float, ...}
        logit_bias = _map_logit_bias_to_native(logit_bias, req_params['requested_model'])
        logits_processor = [LogitsBiasProcessor(logit_bias)]

    if 'logprobs' in body:  # coming to chat eventually
        logprobs = default(body, 'logprobs', 0)  # maybe cap at topk? don't clamp 0-5.
        req_params['logprob_proc'] = LogprobProcessor(logprobs)
        logits_processor.append(req_params['logprob_proc'])

    if logits_processor:  # requires logits_processor support
        req_params['logits_processor'] = LogitsProcessorList(logits_processor)

    return req_params
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
    """Render an OpenAI chat ``messages`` list into a single text prompt.

    Layout: joined system messages, then the context (with ``body['prompt']``
    prepended if a client sent one), then the formatted user/assistant turns,
    then the bot prompt that cues the model to answer. Turn formats come from
    the instruction template named by ``shared.settings['instruction_template']``
    when it loads, otherwise from a generic "User:/Assistant:" format.

    Side effect: ensures ``req_params['stopping_strings']`` exists and extends
    it with stop strings matching the chosen turn format.

    Args:
        body: parsed request JSON.
        req_params: generation parameters built by ``marshal_common_params``.
        max_tokens: requested completion budget; 0 means "not specified".

    Returns:
        ``(prompt, token_count)``; ``token_count`` is the prompt length in
        native tokens.

    Raises:
        InvalidRequestError: function calling requested, messages missing or
        malformed, an unsupported role, or a prompt that already fills the
        context window.
    """
    # functions
    if body.get('functions', []):  # chat only
        raise InvalidRequestError(message="functions is not supported.", param='functions')

    if body.get('function_call', ''):  # chat only, 'none', 'auto', {'name': 'func'}
        raise InvalidRequestError(message="function_call is not supported.", param='function_call')

    if 'messages' not in body:
        raise InvalidRequestError(message="messages is required", param='messages')

    messages = body['messages']

    # Generic fallback format, used when no instruction template is loaded.
    role_formats = {
        'user': 'User: {message}\n',
        'assistant': 'Assistant: {message}\n',
        'system': '{message}',
        'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?',
        'prompt': 'Assistant:',
    }

    if 'stopping_strings' not in req_params:
        req_params['stopping_strings'] = []

    # Instruct models can be much better
    template_name = shared.settings['instruction_template']
    if template_name:
        try:
            with open(f"instruction-templates/{template_name}.yaml", 'r', encoding='utf-8') as f:
                instruct = yaml.safe_load(f)

            template = instruct['turn_template']
            user_name = instruct.get('user') or ''
            bot_name = instruct.get('bot') or ''

            # So far, 100% of instruction templates have this token: text before
            # it is the user turn, text from it on is the bot turn.
            # NOTE(review): a template without '<|bot|>' gives find() == -1 and
            # splits at the last character — confirm templates always have it.
            bot_start = template.find('<|bot|>')
            user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', user_name)
            bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', bot_name)
            # The bot prompt is the bot turn up to where its message would go.
            bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

            if 'Alpaca' in template_name:
                instruct_stops = ['\n###']
            elif user_name:  # WizardLM and some others have no user prompt.
                instruct_stops = ['\n' + user_name, user_name]
            else:
                instruct_stops = []

            # Commit only after everything above succeeded, so a broken template
            # falls back to the defaults as a whole, not half-applied.
            role_formats = {
                'user': user_message_template,
                'assistant': bot_message_template,
                'system': "{message}",
                'context': instruct.get('context', ''),  # can be missing
                'prompt': bot_prompt,
            }
            req_params['stopping_strings'].extend(instruct_stops)
            debug_msg(f"Loaded instruction role format: {template_name}")
        except Exception as e:
            # Best-effort: any template problem falls back to the generic format.
            req_params['stopping_strings'].extend(['\nUser:', 'User:'])  # XXX User: prompt here also
            print(f"Exception: When loading instruction-templates/{template_name}.yaml: {repr(e)}")
            print("Warning: Loaded default instruction-following template for model.")
    else:
        req_params['stopping_strings'].extend(['\nUser:', 'User:'])  # XXX User: prompt here also
        print("Warning: Loaded default instruction-following template for model.")

    system_msgs = []
    chat_msgs = []

    # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
    context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
    context_msg = end_line(context_msg)

    # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
    if 'prompt' in body:
        context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg

    for m in messages:
        if 'role' not in m:
            raise InvalidRequestError(message="messages: missing role", param='messages')
        if 'content' not in m:
            raise InvalidRequestError(message="messages: missing content", param='messages')

        role = m['role']
        content = m['content']
        # name = m.get('name', None)
        # function_call = m.get('function_call', None) # user name or function name with output in content

        # Validate the role before looking up its format: 'function' and other
        # unknown roles have no entry in role_formats, so the lookup would
        # raise KeyError before a proper client error could be reported.
        if role == 'function':
            raise InvalidRequestError(message="role: function is not supported.", param='messages')
        if role not in ('system', 'user', 'assistant'):
            raise InvalidRequestError(message=f"messages: unsupported role {role!r}", param='messages')

        msg = role_formats[role].format(message=content)
        if role == 'system':
            system_msgs.append(msg)
        else:
            chat_msgs.append(msg)

    system_msg = end_line('\n'.join(system_msgs))

    prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']

    token_count = len(encode(prompt)[0])

    if token_count >= req_params['truncation_length']:
        err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
        raise InvalidRequestError(message=err_msg, param='messages')

    if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
        err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
        # Deliberately only a warning (not an error): the chat endpoints clamp
        # max_new_tokens to the remaining context afterwards.
        print(f"Warning: {err_msg}")

    return prompt, token_count
2023-07-12 20:33:25 +02:00
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
    """Run a blocking (non-streaming) chat completion and build its response.

    Args:
        body: parsed OpenAI-style request; ``messages`` is rendered into a
            prompt by ``messages_to_prompt``.
        is_legacy: use the legacy field names (``length`` for the token limit,
            ``data`` for the list of choices).

    Returns:
        A ``chat.completions`` response dict with one choice and a ``usage``
        section counted in native tokens.
    """
    created_time = int(time.time())
    cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
    choices_key = 'data' if is_legacy else 'choices'
    limit_key = 'length' if is_legacy else 'max_tokens'

    req_params = marshal_common_params(body)
    req_params['stream'] = False
    requested_model = req_params.pop('requested_model')
    logprob_proc = req_params.pop('logprob_proc', None)
    # Chat has no best_of/top_k knob, but output is much better with a higher top_k.
    req_params['top_k'] = 20

    # Chat's default max_tokens is effectively unlimited: without a client
    # limit, allow the whole context and tell the prompt builder "0" (unset).
    if limit_key not in body:
        max_tokens = 0
        req_params['max_new_tokens'] = req_params['truncation_length']
    else:
        max_tokens = default(body, limit_key, req_params['truncation_length'])
        req_params['max_new_tokens'] = max_tokens

    # Also extends req_params['stopping_strings'].
    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)

    # Never request more new tokens than the context window has room for.
    space_left = req_params['truncation_length'] - token_count
    if req_params['max_new_tokens'] >= space_left:
        req_params['max_new_tokens'] = space_left

    stopping_strings = req_params.pop('stopping_strings', [])

    debug_msg({'prompt': prompt, 'req_params': req_params})
    answer = ''
    for answer in generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False):
        pass  # each item is the full reply so far; only the last one matters

    # The model tends to start with a space; drop a single leading one.
    if answer[:1] == ' ':
        answer = answer[1:]

    completion_token_count = len(encode(answer)[0])
    out_of_room = (
        token_count + completion_token_count >= req_params['truncation_length']
        or completion_token_count >= req_params['max_new_tokens']
    )

    choice = {
        "index": 0,
        "finish_reason": "length" if out_of_room else "stop",
        "message": {"role": "assistant", "content": answer},
    }
    if logprob_proc:  # not official for chat yet
        alternatives = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
        choice["logprobs"] = {'top_logprobs': [alternatives]}

    return {
        "id": cmpl_id,
        "object": 'chat.completions',
        "created": created_time,
        "model": shared.model_name,  # TODO: add Lora info?
        choices_key: [choice],
        "usage": {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count,
        },
    }
# generator
2023-07-12 20:33:25 +02:00
def stream_chat_completions(body: dict, is_legacy: bool = False):
    """Generator producing a streamed chat completion as response chunks.

    Yields an initial empty chunk, one chunk per newly generated piece of
    text, and a final empty chunk carrying ``finish_reason`` and ``usage``.
    Text ending in U+FFFD (a multi-byte character split across tokens) is
    held back until it completes. Whatever is still held back when generation
    ends is flushed before the final chunk, so no output is lost.

    Args:
        body: parsed OpenAI-style request.
        is_legacy: use legacy field names (``length`` / ``data``).
    """
    # Chat Completions
    stream_object_type = 'chat.completions.chunk'
    created_time = int(time.time())
    cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    # common params
    req_params = marshal_common_params(body)
    req_params['stream'] = True
    requested_model = req_params.pop('requested_model')
    logprob_proc = req_params.pop('logprob_proc', None)
    req_params['top_k'] = 20  # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.

    # chat default max_tokens is 'inf', but also flexible
    max_tokens = 0
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    if max_tokens_str in body:
        max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
        req_params['max_new_tokens'] = max_tokens
    else:
        req_params['max_new_tokens'] = req_params['truncation_length']

    # format the prompt from messages
    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)  # updates req_params['stopping_strings']

    # set real max, avoid deeper errors
    if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
        req_params['max_new_tokens'] = req_params['truncation_length'] - token_count

    def chat_streaming_chunk(content):
        """Build one streaming chunk dict carrying ``content``."""
        chunk = {
            "id": cmpl_id,
            "object": stream_object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                # So yeah... do both methods? delta and messages.
                "message": {'role': 'assistant', 'content': content},
                "delta": {'role': 'assistant', 'content': content},
            }],
        }

        if logprob_proc:  # not official for chat yet
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
        # else:
        #    chunk[resp_list][0]["logprobs"] = None
        return chunk

    yield chat_streaming_chunk('')

    # generate reply #######################################
    debug_msg({'prompt': prompt, 'req_params': req_params})

    stopping_strings = req_params.pop('stopping_strings', [])

    generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

    answer = ''
    seen_content = ''

    for a in generator:
        answer = a

        len_seen = len(seen_content)
        new_content = answer[len_seen:]

        # A trailing U+FFFD means the last token stopped mid-way through a
        # multi-byte character: hold it back until the character completes.
        # (Checking only the end avoids stalling forever on a real U+FFFD.)
        if not new_content or new_content.endswith(chr(0xfffd)):
            continue

        seen_content = answer

        # strip extra leading space off new generated content
        if len_seen == 0 and new_content[0] == ' ':
            new_content = new_content[1:]

        yield chat_streaming_chunk(new_content)

    # Flush anything still held back so the streamed text matches the answer.
    tail = answer[len(seen_content):]
    if not seen_content and tail[:1] == ' ':
        tail = tail[1:]
    if tail:
        yield chat_streaming_chunk(tail)

    # to get the correct token_count, strip leading space if present
    if answer and answer[0] == ' ':
        answer = answer[1:]

    completion_token_count = len(encode(answer)[0])
    stop_reason = "stop"
    if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
        stop_reason = "length"

    chunk = chat_streaming_chunk('')
    chunk[resp_list][0]['finish_reason'] = stop_reason
    chunk['usage'] = {
        "prompt_tokens": token_count,
        "completion_tokens": completion_token_count,
        "total_tokens": token_count + completion_token_count
    }

    yield chunk
2023-07-12 20:33:25 +02:00
def completions(body: dict, is_legacy: bool = False):
    """Handle a blocking (non-streaming) text completion request.

    The prompt field (``prompt``, or ``context`` for legacy requests) may be
    a string, a list of strings, a list of token ids, or a list of token-id
    lists. Each prompt is generated separately and returned as its own
    choice. Usage totals are summed over all prompts, in native tokens.

    Raises:
        InvalidRequestError: the prompt is missing/empty, or a prompt plus
        ``max_tokens`` exceeds the context length.
    """
    # Legacy
    # Text Completions
    object_type = 'text_completion'
    created_time = int(time.time())
    cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
    prompt_str = 'context' if is_legacy else 'prompt'
    if prompt_str not in body:
        raise InvalidRequestError("Missing required input", param=prompt_str)

    prompt_arg = body[prompt_str]
    # Normalize to a list of prompts: one string or one flat token list is a
    # batch of one. Guard the [0] probe so an empty list can't IndexError.
    if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and prompt_arg and isinstance(prompt_arg[0], int)):
        prompt_arg = [prompt_arg]
    if not isinstance(prompt_arg, list) or not prompt_arg:
        raise InvalidRequestError("Missing required input", param=prompt_str)

    # common params
    req_params = marshal_common_params(body)
    req_params['stream'] = False
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
    req_params['max_new_tokens'] = max_tokens
    requested_model = req_params.pop('requested_model')
    logprob_proc = req_params.pop('logprob_proc', None)
    stopping_strings = req_params.pop('stopping_strings', [])
    # req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
    req_params['echo'] = default(body, 'echo', req_params['echo'])
    req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

    resp_list_data = []
    total_completion_token_count = 0
    total_prompt_token_count = 0

    for idx, prompt in enumerate(prompt_arg):
        # Test the type rather than probing prompt[0]: an empty string prompt
        # would raise IndexError.
        if isinstance(prompt, list):
            # token lists. decode() returns the text itself, so it is not
            # indexed (indexing would keep only the first character).
            if requested_model == shared.model_name:
                prompt = decode(prompt)
            else:
                try:
                    encoder = tiktoken.encoding_for_model(requested_model)
                    prompt = encoder.decode(prompt)
                except KeyError:
                    prompt = decode(prompt)

        token_count = len(encode(prompt)[0])
        total_prompt_token_count += token_count

        if token_count + max_tokens > req_params['truncation_length']:
            err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
            raise InvalidRequestError(message=err_msg, param=max_tokens_str)

        # generate reply #######################################
        debug_msg({'prompt': prompt, 'req_params': req_params})
        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
        answer = ''
        for a in generator:
            answer = a

        # strip extra leading space off new generated content
        if answer and answer[0] == ' ':
            answer = answer[1:]

        completion_token_count = len(encode(answer)[0])
        total_completion_token_count += completion_token_count
        stop_reason = "stop"
        if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
            stop_reason = "length"

        logprobs = None
        if logprob_proc:
            # Same conversion hook as the chat endpoints use.
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            logprobs = {'top_logprobs': [top_logprobs]}

        resp_list_data.append({
            "index": idx,
            "finish_reason": stop_reason,
            "text": answer,
            "logprobs": logprobs,
        })

    resp = {
        "id": cmpl_id,
        "object": object_type,
        "created": created_time,
        "model": shared.model_name,  # TODO: add Lora info?
        resp_list: resp_list_data,
        "usage": {
            "prompt_tokens": total_prompt_token_count,
            "completion_tokens": total_completion_token_count,
            "total_tokens": total_prompt_token_count + total_completion_token_count
        }
    }

    return resp
# generator
2023-07-12 20:33:25 +02:00
def stream_completions(body: dict, is_legacy: bool = False):
    """Generator producing a streamed text completion as response chunks.

    Accepts a single prompt: a string, a token-id list, or a one-element list
    holding either. Larger batches are rejected. Yields an initial empty
    chunk, one chunk per newly generated piece of text, and a final empty
    chunk carrying ``finish_reason`` and ``usage``. Text ending in U+FFFD (a
    split multi-byte character) is held back until it completes, and any
    held-back text is flushed before the final chunk.

    Raises:
        InvalidRequestError: missing prompt, batched prompts, or prompt plus
        ``max_tokens`` exceeding the context length.
    """
    # Legacy
    # Text Completions
    # object_type = 'text_completion'
    stream_object_type = 'text_completion.chunk'
    created_time = int(time.time())
    cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
    prompt_str = 'context' if is_legacy else 'prompt'
    if prompt_str not in body:
        raise InvalidRequestError("Missing required input", param=prompt_str)

    prompt = body[prompt_str]
    req_params = marshal_common_params(body)
    requested_model = req_params.pop('requested_model')

    # A batch of exactly one prompt (string or token list) is just that prompt.
    if isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], (str, list)):
        prompt = prompt[0]

    if isinstance(prompt, list):
        if prompt and isinstance(prompt[0], int):
            # token list. decode() returns the text itself, so it is not
            # indexed (indexing would keep only the first character).
            if requested_model == shared.model_name:
                prompt = decode(prompt)
            else:
                try:
                    encoder = tiktoken.encoding_for_model(requested_model)
                    prompt = encoder.decode(prompt)
                except KeyError:
                    prompt = decode(prompt)
        else:
            raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)

    # common params
    req_params['stream'] = True
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
    req_params['max_new_tokens'] = max_tokens
    logprob_proc = req_params.pop('logprob_proc', None)
    stopping_strings = req_params.pop('stopping_strings', [])
    # req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
    req_params['echo'] = default(body, 'echo', req_params['echo'])
    req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

    token_count = len(encode(prompt)[0])

    if token_count + max_tokens > req_params['truncation_length']:
        err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
        # print(f"Warning: ${err_msg}")
        raise InvalidRequestError(message=err_msg, param=max_tokens_str)

    def text_streaming_chunk(content):
        """Build one streaming chunk dict carrying ``content``."""
        logprobs = None
        if logprob_proc:
            # Same conversion hook as the chat endpoints use.
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            logprobs = {'top_logprobs': [top_logprobs]}

        chunk = {
            "id": cmpl_id,
            "object": stream_object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                "text": content,
                "logprobs": logprobs,
            }],
        }

        return chunk

    yield text_streaming_chunk('')

    # generate reply #######################################
    debug_msg({'prompt': prompt, 'req_params': req_params})
    generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

    answer = ''
    seen_content = ''

    for a in generator:
        answer = a

        len_seen = len(seen_content)
        new_content = answer[len_seen:]

        # A trailing U+FFFD means the last token stopped mid-way through a
        # multi-byte character: hold it back until the character completes.
        # (Checking only the end avoids stalling forever on a real U+FFFD.)
        if not new_content or new_content.endswith(chr(0xfffd)):
            continue

        seen_content = answer

        # strip extra leading space off new generated content
        if len_seen == 0 and new_content[0] == ' ':
            new_content = new_content[1:]

        yield text_streaming_chunk(new_content)

    # Flush anything still held back so the streamed text matches the answer.
    tail = answer[len(seen_content):]
    if not seen_content and tail[:1] == ' ':
        tail = tail[1:]
    if tail:
        yield text_streaming_chunk(tail)

    # to get the correct count, we strip the leading space if present
    if answer and answer[0] == ' ':
        answer = answer[1:]

    completion_token_count = len(encode(answer)[0])
    stop_reason = "stop"
    if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
        stop_reason = "length"

    chunk = text_streaming_chunk('')
    chunk[resp_list][0]["finish_reason"] = stop_reason
    chunk["usage"] = {
        "prompt_tokens": token_count,
        "completion_tokens": completion_token_count,
        "total_tokens": token_count + completion_token_count
    }

    yield chunk