import base64
import json
import os
import time

import requests
import yaml
import numpy as np

from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

from modules.utils import get_available_models
from modules.models import load_model, unload_model
from modules.models_settings import (get_model_settings_from_yamls,
                                     update_model_parameters)

from modules import shared
from modules.text_generation import encode, generate_reply

params = {
    'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
}

debug = True if 'OPENEDAI_DEBUG' in os.environ else False
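# Example (a sketch, not from this file): only OPENEDAI_PORT and OPENEDAI_DEBUG are read here;
# the launch command assumes the standard text-generation-webui entry point:
#   OPENEDAI_PORT=5002 OPENEDAI_DEBUG=1 python server.py --extensions openai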

# Slightly different defaults for OpenAI's API
# Data type is important, Ex. use 0.0 for a float 0
default_req_params = {
    'max_new_tokens': 200,
    'temperature': 1.0,
    'top_p': 1.0,
    'top_k': 1,
    'repetition_penalty': 1.18,
    'repetition_penalty_range': 0,
    'encoder_repetition_penalty': 1.0,
    'suffix': None,
    'stream': False,
    'echo': False,
    'seed': -1,
    # 'n': default(body, 'n', 1),  # 'n' doesn't have a direct map
    'truncation_length': 2048,
    'add_bos_token': True,
    'do_sample': True,
    'typical_p': 1.0,
    'epsilon_cutoff': 0.0,  # In units of 1e-4
    'eta_cutoff': 0.0,  # In units of 1e-4
    'tfs': 1.0,
    'top_a': 0.0,
    'min_length': 0,
    'no_repeat_ngram_size': 0,
    'num_beams': 1,
    'penalty_alpha': 0.0,
    'length_penalty': 1.0,
    'early_stopping': False,
    'mirostat_mode': 0,
    'mirostat_tau': 5.0,
    'mirostat_eta': 0.1,
    'ban_eos_token': False,
    'skip_special_tokens': True,
    'custom_stopping_strings': '',
}

# Optional, install the module and download the model to enable
# v1/embeddings
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    pass

st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
embedding_model = None


# little helper to get defaults if arg is present but None and should be the same type as default.
def default(dic, key, default):
    val = dic.get(key, default)
    if type(val) != type(default):
        # maybe it's just something like 1 instead of 1.0
        try:
            v = type(default)(val)
            if type(val)(v) == val:  # if it's the same value passed in, it's ok.
                return v
        except:
            pass

        val = default

    return val
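# Illustrative behavior (hypothetical examples, not taken from the original callers):
#   default({'temperature': 1}, 'temperature', 1.0)  ->  1.0   (int 1 coerces cleanly to the float default's type)
#   default({'stop': None}, 'stop', '')              ->  ''    (None cannot be round-tripped, so the default wins)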


def clamp(value, minvalue, maxvalue):
    return max(minvalue, min(value, maxvalue))


def float_list_to_base64(float_list):
    # Convert the list to a float32 array that the OpenAPI client expects
    float_array = np.array(float_list, dtype="float32")

    # Get raw bytes
    bytes_array = float_array.tobytes()

    # Encode bytes into base64
    encoded_bytes = base64.b64encode(bytes_array)

    # Turn raw base64 encoded bytes into ASCII
    ascii_string = encoded_bytes.decode('ascii')
    return ascii_string
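# Decoding sketch for reference (assumes a client with numpy available; not part of this server):
#   np.frombuffer(base64.b64decode(ascii_string), dtype=np.float32)
# recovers the original embedding as a float32 array.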


class Handler(BaseHTTPRequestHandler):
    def send_access_control_headers(self):
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Credentials", "true")
        self.send_header(
            "Access-Control-Allow-Methods",
            "GET,HEAD,OPTIONS,POST,PUT"
        )
        self.send_header(
            "Access-Control-Allow-Headers",
            "Origin, Accept, X-Requested-With, Content-Type, "
            "Access-Control-Request-Method, Access-Control-Request-Headers, "
            "Authorization"
        )

    def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
        self.send_response(code)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        error_resp = {
            'error': {
                'message': message,
                'code': code,
                'type': error_type,
                'param': param,
            }
        }
        if internal_message:
            error_resp['internal_message'] = internal_message

        response = json.dumps(error_resp)
        self.wfile.write(response.encode('utf-8'))

    def do_OPTIONS(self):
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write("OK".encode('utf-8'))
    def do_GET(self):
        if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
            current_model_list = [shared.model_name]  # The real chat/completions model, maybe "None"
            embeddings_model_list = [st_model] if embedding_model else []  # The real sentence transformer embeddings model
            pseudo_model_list = [  # these are expected by so much, so include some here as a dummy
                'gpt-3.5-turbo',  # /v1/chat/completions
                'text-curie-001',  # /v1/completions, 2k context
                'text-davinci-002'  # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
            ]

            is_legacy = 'engines' in self.path
            is_list = self.path in ['/v1/engines', '/v1/models']

            resp = ''

            if is_legacy and not is_list:  # load model
                model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]

                resp = {
                    "id": model_name,
                    "object": "engine",
                    "owner": "self",
                    "ready": True,
                }
                if model_name not in pseudo_model_list + embeddings_model_list + current_model_list:  # Real model only
                    # No args. Maybe it works anyways!
                    # TODO: hack some heuristics into args for better results
                    shared.model_name = model_name
                    unload_model()

                    model_settings = get_model_settings_from_yamls(shared.model_name)
                    shared.settings.update(model_settings)
                    update_model_parameters(model_settings, initial=True)

                    if shared.settings['mode'] != 'instruct':
                        shared.settings['instruction_template'] = None

                    shared.model, shared.tokenizer = load_model(shared.model_name)

                    if not shared.model:  # load failed.
                        shared.model_name = "None"
                        resp['id'] = "None"
                        resp['ready'] = False

            elif is_list:
                # TODO: Lora's?
                available_model_list = get_available_models()
                all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list

                models = {}

                if is_legacy:
                    models = [{"id": id, "object": "engine", "owner": "user", "ready": True} for id in all_model_list]
                    if not shared.model:
                        models[0]['ready'] = False
                else:
                    models = [{"id": id, "object": "model", "owned_by": "user", "permission": []} for id in all_model_list]

                resp = {
                    "object": "list",
                    "data": models,
                }

            else:
                the_model_name = self.path[len('/v1/models/'):]
                resp = {
                    "id": the_model_name,
                    "object": "model",
                    "owned_by": "user",
                    "permission": []
                }

            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()
            response = json.dumps(resp)
            self.wfile.write(response.encode('utf-8'))

        elif '/billing/usage' in self.path:
            # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            response = json.dumps({
                "total_usage": 0,
            })
            self.wfile.write(response.encode('utf-8'))

        else:
            self.send_error(404)
    def do_POST(self):
        if debug:
            print(self.headers)  # did you know... python-openai sends your linux kernel & python version?

        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        if debug:
            print(body)

        if '/completions' in self.path or '/generate' in self.path:

            if not shared.model:
                self.openai_error("No model loaded.")
                return

            is_legacy = '/generate' in self.path
            is_chat_request = 'chat' in self.path
            resp_list = 'data' if is_legacy else 'choices'

            # XXX model is ignored for now
            # model = body.get('model', shared.model_name)  # ignored, use existing for now
            model = shared.model_name
            created_time = int(time.time())
            cmpl_id = "chatcmpl-%d" % (created_time) if is_chat_request else "conv-%d" % (created_time)

            # Request Parameters
            # Try to use openai defaults or map them to something with the same intent
            req_params = default_req_params.copy()
            stopping_strings = []

            if 'stop' in body:
                if isinstance(body['stop'], str):
                    stopping_strings.extend([body['stop']])
                elif isinstance(body['stop'], list):
                    stopping_strings.extend(body['stop'])

            truncation_length = default(shared.settings, 'truncation_length', 2048)
            truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)

            default_max_tokens = truncation_length if is_chat_request else 16  # completions default, chat default is 'inf' so we need to cap it.

            max_tokens_str = 'length' if is_legacy else 'max_tokens'
            max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
            # if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough

            req_params['max_new_tokens'] = max_tokens
            req_params['truncation_length'] = truncation_length
            req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999)  # fixup absolute 0.0
            req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
            req_params['top_k'] = default(body, 'best_of', default_req_params['top_k'])
            req_params['suffix'] = default(body, 'suffix', default_req_params['suffix'])
            req_params['stream'] = default(body, 'stream', default_req_params['stream'])
            req_params['echo'] = default(body, 'echo', default_req_params['echo'])
            req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
            req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])

            is_streaming = req_params['stream']
            self.send_response(200)
            self.send_access_control_headers()
            if is_streaming:
                self.send_header('Content-Type', 'text/event-stream')
                self.send_header('Cache-Control', 'no-cache')
                # self.send_header('Connection', 'keep-alive')
            else:
                self.send_header('Content-Type', 'application/json')
            self.end_headers()

            token_count = 0
            completion_token_count = 0
            prompt = ''
            stream_object_type = ''
            object_type = ''

            if is_chat_request:
                # Chat Completions
                stream_object_type = 'chat.completions.chunk'
                object_type = 'chat.completions'

                messages = body['messages']

                role_formats = {
                    'user': 'user: {message}\n',
                    'assistant': 'assistant: {message}\n',
                    'system': '{message}',
                    'context': 'You are a helpful assistant. Answer as concisely as possible.',
                    'prompt': 'assistant:',
                }

                # Instruct models can be much better
                if shared.settings['instruction_template']:
                    try:
                        instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))

                        template = instruct['turn_template']
                        system_message_template = "{message}"
                        system_message_default = instruct['context']
                        bot_start = template.find('<|bot|>')  # So far, 100% of instruction templates have this token
                        user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
                        bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
                        bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

                        role_formats = {
                            'user': user_message_template,
                            'assistant': bot_message_template,
                            'system': system_message_template,
                            'context': system_message_default,
                            'prompt': bot_prompt,
                        }

                        if 'Alpaca' in shared.settings['instruction_template']:
                            stopping_strings.extend(['\n###'])
                        elif instruct['user']:  # WizardLM and some others have no user prompt.
                            stopping_strings.extend(['\n' + instruct['user'], instruct['user']])

                        if debug:
                            print(f"Loaded instruction role format: {shared.settings['instruction_template']}")

                    except Exception as e:
                        stopping_strings.extend(['\nuser:'])
                        print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
                        print("Warning: Loaded default instruction-following template for model.")

                else:
                    stopping_strings.extend(['\nuser:'])
                    print("Warning: Loaded default instruction-following template for model.")
                system_msgs = []
                chat_msgs = []

                # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
                context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
                if context_msg:
                    system_msgs.extend([context_msg])

                # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
                if 'prompt' in body:
                    prompt_msg = role_formats['system'].format(message=body['prompt'])
                    system_msgs.extend([prompt_msg])

                for m in messages:
                    role = m['role']
                    content = m['content']
                    msg = role_formats[role].format(message=content)
                    if role == 'system':
                        system_msgs.extend([msg])
                    else:
                        chat_msgs.extend([msg])

                # can't really truncate the system messages
                system_msg = '\n'.join(system_msgs)
                if system_msg and system_msg[-1] != '\n':
                    system_msg = system_msg + '\n'

                system_token_count = len(encode(system_msg)[0])
                remaining_tokens = truncation_length - system_token_count
                chat_msg = ''

                while chat_msgs:
                    new_msg = chat_msgs.pop()
                    new_size = len(encode(new_msg)[0])
                    if new_size <= remaining_tokens:
                        chat_msg = new_msg + chat_msg
                        remaining_tokens -= new_size
                    else:
                        print(f"Warning: too many messages for context size, dropping {len(chat_msgs) + 1} oldest message(s).")
                        break

                prompt = system_msg + chat_msg + role_formats['prompt']
                token_count = len(encode(prompt)[0])
            else:
                # Text Completions
                stream_object_type = 'text_completion.chunk'
                object_type = 'text_completion'

                # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
                if is_legacy:
                    prompt = body['context']  # Older engines.generate API
                else:
                    prompt = body['prompt']  # XXX this can be different types

                if isinstance(prompt, list):
                    self.openai_error("API Batched generation not yet supported.")
                    return

                token_count = len(encode(prompt)[0])

                if token_count >= truncation_length:
                    new_len = int(len(prompt) * shared.settings['truncation_length'] / token_count)
                    prompt = prompt[-new_len:]
                    new_token_count = len(encode(prompt)[0])
                    print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
                    token_count = new_token_count

            if truncation_length - token_count < req_params['max_new_tokens']:
                print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {truncation_length - token_count}")
                req_params['max_new_tokens'] = truncation_length - token_count
                print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
            if is_streaming:
                # begin streaming
                chunk = {
                    "id": cmpl_id,
                    "object": stream_object_type,
                    "created": created_time,
                    "model": shared.model_name,
                    resp_list: [{
                        "index": 0,
                        "finish_reason": None,
                    }],
                }

                if stream_object_type == 'text_completion.chunk':
                    chunk[resp_list][0]["text"] = ""
                else:
                    # So yeah... do both methods? delta and messages.
                    chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
                    chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}

                response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
                self.wfile.write(response.encode('utf-8'))

            # generate reply #######################################
            if debug:
                print({'prompt': prompt, 'req_params': req_params})

            generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
            answer = ''
            seen_content = ''
            longest_stop_len = max([len(x) for x in stopping_strings] + [0])

            for a in generator:
                answer = a

                stop_string_found = False
                len_seen = len(seen_content)
                search_start = max(len_seen - longest_stop_len, 0)

                for string in stopping_strings:
                    idx = answer.find(string, search_start)
                    if idx != -1:
                        answer = answer[:idx]  # clip it.
                        stop_string_found = True

                if stop_string_found:
                    break

                # If something like "\nYo" is generated just before "\nYou:"
                # is completed, buffer and generate more, don't send it
                buffer_and_continue = False

                for string in stopping_strings:
                    for j in range(len(string) - 1, 0, -1):
                        if answer[-j:] == string[:j]:
                            buffer_and_continue = True
                            break
                    else:
                        continue
                    break

                if buffer_and_continue:
                    continue
                if is_streaming:
                    # Streaming
                    new_content = answer[len_seen:]

                    if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                        continue

                    seen_content = answer
                    chunk = {
                        "id": cmpl_id,
                        "object": stream_object_type,
                        "created": created_time,
                        "model": shared.model_name,
                        resp_list: [{
                            "index": 0,
                            "finish_reason": None,
                        }],
                    }

                    # strip extra leading space off new generated content
                    if len_seen == 0 and new_content[0] == ' ':
                        new_content = new_content[1:]

                    if stream_object_type == 'text_completion.chunk':
                        chunk[resp_list][0]['text'] = new_content
                    else:
                        # So yeah... do both methods? delta and messages.
                        chunk[resp_list][0]['message'] = {'content': new_content}
                        chunk[resp_list][0]['delta'] = {'content': new_content}

                    response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
                    self.wfile.write(response.encode('utf-8'))
                    completion_token_count += len(encode(new_content)[0])
            if is_streaming:
                chunk = {
                    "id": cmpl_id,
                    "object": stream_object_type,
                    "created": created_time,
                    "model": model,  # TODO: add Lora info?
                    resp_list: [{
                        "index": 0,
                        "finish_reason": "stop",
                    }],
                    "usage": {
                        "prompt_tokens": token_count,
                        "completion_tokens": completion_token_count,
                        "total_tokens": token_count + completion_token_count
                    }
                }
                if stream_object_type == 'text_completion.chunk':
                    chunk[resp_list][0]['text'] = ''
                else:
                    # So yeah... do both methods? delta and messages.
                    chunk[resp_list][0]['message'] = {'content': ''}
                    chunk[resp_list][0]['delta'] = {'content': ''}

                response = 'data: ' + json.dumps(chunk) + '\r\n\r\ndata: [DONE]\r\n\r\n'
                self.wfile.write(response.encode('utf-8'))
                # Finished if streaming.
                if debug:
                    if answer and answer[0] == ' ':
                        answer = answer[1:]
                    print({'answer': answer}, chunk)
                return
            # strip extra leading space off new generated content
            if answer and answer[0] == ' ':
                answer = answer[1:]

            if debug:
                print({'response': answer})

            completion_token_count = len(encode(answer)[0])
            stop_reason = "stop"
            if token_count + completion_token_count >= truncation_length:
                stop_reason = "length"

            resp = {
                "id": cmpl_id,
                "object": object_type,
                "created": created_time,
                "model": model,  # TODO: add Lora info?
                resp_list: [{
                    "index": 0,
                    "finish_reason": stop_reason,
                }],
                "usage": {
                    "prompt_tokens": token_count,
                    "completion_tokens": completion_token_count,
                    "total_tokens": token_count + completion_token_count
                }
            }

            if is_chat_request:
                resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
            else:
                resp[resp_list][0]["text"] = answer

            response = json.dumps(resp)
            self.wfile.write(response.encode('utf-8'))
        elif '/edits' in self.path:
            if not shared.model:
                self.openai_error("No model loaded.")
                return

            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            created_time = int(time.time())

            # Using Alpaca format, this may work with other models too.
            instruction = body['instruction']
            input = body.get('input', '')

            # Request parameters
            req_params = default_req_params.copy()
            stopping_strings = []

            # Alpaca is verbose so a good default prompt
            default_template = (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
            )

            instruction_template = default_template

            # Use the special instruction/input/response template for anything trained like Alpaca
            if shared.settings['instruction_template']:
                if 'Alpaca' in shared.settings['instruction_template']:
                    stopping_strings.extend(['\n###'])
                else:
                    try:
                        instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))

                        template = instruct['turn_template']
                        template = template\
                            .replace('<|user|>', instruct.get('user', ''))\
                            .replace('<|bot|>', instruct.get('bot', ''))\
                            .replace('<|user-message|>', '{instruction}\n{input}')

                        instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
                        if instruct['user']:
                            stopping_strings.extend(['\n' + instruct['user'], instruct['user']])

                    except Exception as e:
                        instruction_template = default_template
                        print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
                        print("Warning: Loaded default instruction-following template (Alpaca) for model.")
            else:
                stopping_strings.extend(['\n###'])
                print("Warning: Loaded default instruction-following template (Alpaca) for model.")
2023-05-11 16:06:39 +02:00
edit_task = instruction_template . format ( instruction = instruction , input = input )
truncation_length = default ( shared . settings , ' truncation_length ' , 2048 )
token_count = len ( encode ( edit_task ) [ 0 ] )
max_tokens = truncation_length - token_count
2023-05-31 02:52:33 +02:00
req_params [ ' max_new_tokens ' ] = max_tokens
req_params [ ' truncation_length ' ] = truncation_length
req_params [ ' temperature ' ] = clamp ( default ( body , ' temperature ' , default_req_params [ ' temperature ' ] ) , 0.001 , 1.999 ) # fixup absolute 0.0
req_params [ ' top_p ' ] = clamp ( default ( body , ' top_p ' , default_req_params [ ' top_p ' ] ) , 0.001 , 1.0 )
req_params [ ' seed ' ] = shared . settings . get ( ' seed ' , default_req_params [ ' seed ' ] )
req_params [ ' add_bos_token ' ] = shared . settings . get ( ' add_bos_token ' , default_req_params [ ' add_bos_token ' ] )
2023-05-11 16:06:39 +02:00
if debug :
print ( { ' edit_template ' : edit_task , ' req_params ' : req_params , ' token_count ' : token_count } )
2023-06-25 03:50:04 +02:00
generator = generate_reply ( edit_task , req_params , stopping_strings = stopping_strings , is_chat = False )
2023-05-11 16:06:39 +02:00
2023-06-25 03:50:04 +02:00
longest_stop_len = max ( [ len ( x ) for x in stopping_strings ] + [ 0 ] )
2023-05-11 16:06:39 +02:00
answer = ' '
2023-06-18 00:15:24 +02:00
seen_content = ' '
2023-05-11 16:06:39 +02:00
for a in generator :
2023-05-11 22:07:20 +02:00
answer = a
2023-05-11 16:06:39 +02:00
2023-06-18 00:15:24 +02:00
stop_string_found = False
len_seen = len ( seen_content )
search_start = max ( len_seen - longest_stop_len , 0 )
2023-06-25 03:50:04 +02:00
for string in stopping_strings :
2023-06-18 00:15:24 +02:00
idx = answer . find ( string , search_start )
if idx != - 1 :
answer = answer [ : idx ] # clip it.
stop_string_found = True
if stop_string_found :
break
2023-05-14 17:57:52 +02:00
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
if edit_task [ - 1 ] != ' \n ' and answer and answer [ 0 ] == ' ' :
answer = answer [ 1 : ]
2023-05-11 16:06:39 +02:00
completion_token_count = len ( encode ( answer ) [ 0 ] )
resp = {
" object " : " edit " ,
" created " : created_time ,
" choices " : [ {
" text " : answer ,
" index " : 0 ,
} ] ,
" usage " : {
" prompt_tokens " : token_count ,
" completion_tokens " : completion_token_count ,
" total_tokens " : token_count + completion_token_count
}
}
if debug :
print ( { ' answer ' : answer , ' completion_token_count ' : completion_token_count } )
response = json . dumps ( resp )
self . wfile . write ( response . encode ( ' utf-8 ' ) )
        elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
            # Stable Diffusion callout wrapper for txt2img
            # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E,
            # the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
            # If you want high quality tailored results you should just use the Stable Diffusion API directly.
            # It's too general an API to try and shape the result with specific tags like "masterpiece", etc.
            # Will probably work best with the stock SD models.
            # SD configuration is beyond the scope of this API.
            # At this point I will not add the edits and variations endpoints (ie. img2img) because they
            # require changing the form data handling to accept multipart form data, and to properly support
            # url return types would require file management and a web server for files... Perhaps later!

            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            width, height = [int(x) for x in default(body, 'size', '1024x1024').split('x')]  # ignore the restrictions on size
            response_format = default(body, 'response_format', 'url')  # or b64_json

            payload = {
                'prompt': body['prompt'],  # ignore prompt limit of 1000 characters
                'width': width,
                'height': height,
                'batch_size': default(body, 'n', 1)  # ignore the batch limits of max 10
            }

            resp = {
                'created': int(time.time()),
                'data': []
            }

            # TODO: support SD_WEBUI_AUTH username:password pair.
            sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"

            response = requests.post(url=sd_url, json=payload)
            r = response.json()
            # r['parameters']...
            for b64_json in r['images']:
                if response_format == 'b64_json':
                    resp['data'].extend([{'b64_json': b64_json}])
                else:
                    resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}])  # yeah it's lazy. requests.get() will not work with this

            response = json.dumps(resp)
            self.wfile.write(response.encode('utf-8'))
        elif '/embeddings' in self.path and embedding_model is not None:
            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            input = body['input'] if 'input' in body else body['text']
            if type(input) is str:
                input = [input]

            embeddings = embedding_model.encode(input).tolist()

            def enc_emb(emb):
                # If base64 is specified, encode. Otherwise, do nothing.
                if body.get("encoding_format", "") == "base64":
                    return float_list_to_base64(emb)
                else:
                    return emb

            data = [{"object": "embedding", "embedding": enc_emb(emb), "index": n} for n, emb in enumerate(embeddings)]

            response = json.dumps({
                "object": "list",
                "data": data,
                "model": st_model,  # return the real model
                "usage": {
                    "prompt_tokens": 0,
                    "total_tokens": 0,
                }
            })

            if debug:
                print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
            self.wfile.write(response.encode('utf-8'))
        elif '/moderations' in self.path:
            # for now do nothing, just don't error.
            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            response = json.dumps({
                "id": "modr-5MWoLO",
                "model": "text-moderation-001",
                "results": [{
                    "categories": {
                        "hate": False,
                        "hate/threatening": False,
                        "self-harm": False,
                        "sexual": False,
                        "sexual/minors": False,
                        "violence": False,
                        "violence/graphic": False
                    },
                    "category_scores": {
                        "hate": 0.0,
                        "hate/threatening": 0.0,
                        "self-harm": 0.0,
                        "sexual": 0.0,
                        "sexual/minors": 0.0,
                        "violence": 0.0,
                        "violence/graphic": 0.0
                    },
                    "flagged": False
                }]
            })
            self.wfile.write(response.encode('utf-8'))
        elif self.path == '/api/v1/token-count':
            # NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            tokens = encode(body['prompt'])[0]
            response = json.dumps({
                'results': [{
                    'tokens': len(tokens)
                }]
            })
            self.wfile.write(response.encode('utf-8'))

        else:
            print(self.path, self.headers)
            self.send_error(404)


def run_server():
    global embedding_model
    try:
        embedding_model = SentenceTransformer(st_model)
        print(f"\nLoaded embedding model: {st_model}, max sequence length: {embedding_model.max_seq_length}")
    except:
        print(f"\nFailed to load embedding model: {st_model}")

    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)
    if shared.args.share:
        try:
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
        except ImportError:
            print('You should install flask_cloudflared manually')
    else:
        print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')

    server.serve_forever()


def setup():
    Thread(target=run_server, daemon=True).start()
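

# Client-side sketch (an assumption for illustration, not part of this extension): with the
# legacy openai-python 0.x SDK, point the client at this server instead of api.openai.com:
#   import openai
#   openai.api_key = 'dummy'  # any non-empty string; this server does not check keys
#   openai.api_base = 'http://127.0.0.1:5001/v1'
#   openai.ChatCompletion.create(model='gpt-3.5-turbo', messages=[{'role': 'user', 'content': 'Hi'}])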