2023-07-11 17:50:08 -04:00
import time
import yaml
import os
from modules import shared
from extensions . openai . defaults import get_default_req_params
from extensions . openai . utils import debug_msg
from extensions . openai . errors import *
from modules . text_generation import encode , generate_reply
2023-07-12 11:33:25 -07:00
def _escape_format_braces(text: str) -> str:
    """Double literal braces in *text* so a later ``str.format`` leaves them intact."""
    return text.replace('{', '{{').replace('}', '}}')


def _build_instruct_template(instruct: dict) -> tuple:
    """Turn a parsed instruction-following YAML mapping into a prompt template.

    Args:
        instruct: Mapping loaded from the template YAML file. ``turn_template``
            is required; ``user``, ``bot`` and ``context`` are optional.

    Returns:
        A ``(template, stopping_strings)`` pair. ``template`` is a
        ``str.format`` string with ``{instruction}`` and ``{input}``
        placeholders that ends where the bot's message would begin.
        ``stopping_strings`` holds the user-turn marker (with and without a
        leading newline) when one is defined, otherwise it is empty.

    Raises:
        KeyError: If ``turn_template`` is missing from *instruct*.
    """
    # ``or ''`` also covers keys that are present but empty (YAML null).
    user = instruct.get('user') or ''
    bot = instruct.get('bot') or ''
    context = instruct.get('context') or ''

    turn = instruct['turn_template'].replace('<|user|>', user).replace('<|bot|>', bot)
    # Escape braces that come from the YAML text *before* inserting our own
    # placeholders, otherwise a literal '{' in a template breaks str.format.
    turn = _escape_format_braces(turn).replace('<|user-message|>', '{instruction}\n{input}')

    # Keep only the part before the bot's message. partition() returns the
    # whole string when the marker is absent, whereas slicing with find()'s -1
    # would silently drop the last character.
    prompt_part = turn.partition('<|bot-message|>')[0].rstrip(' ')

    stopping_strings = ['\n' + user, user] if user else []
    return _escape_format_braces(context) + prompt_part, stopping_strings


def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
    """Apply *instruction* to *input* with the loaded model and build an edit response.

    The prompt is built from the instruction template named in
    ``shared.settings['instruction_template']``. If no template is set, the
    name contains ``'Alpaca'``, or the template file cannot be loaded, a
    built-in Alpaca-style template is used instead.

    Args:
        instruction: The editing instruction to follow.
        input: The text the instruction is applied to.
        temperature: Sampling temperature passed to the generator.
        top_p: Nucleus-sampling threshold passed to the generator.

    Returns:
        A dict with ``object``, ``created``, ``choices`` (one entry holding the
        generated ``text``) and ``usage`` token counts.

    Raises:
        InvalidRequestError: If the prompt leaves no room for new tokens
            within ``shared.settings['truncation_length']``.
    """
    # NOTE(review): this is milliseconds since the epoch; the OpenAI API
    # documents `created` in seconds. Left unchanged in case clients rely on
    # it - confirm before changing.
    created_time = int(time.time() * 1000)

    # Request parameters
    req_params = get_default_req_params()
    stopping_strings = []

    # Alpaca is verbose so a good default prompt
    default_template = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    )
    instruction_template = default_template

    # Use the special instruction/input/response template for anything trained like Alpaca
    template_name = shared.settings['instruction_template']
    if template_name:
        if 'Alpaca' in template_name:
            stopping_strings.extend(['\n###'])
        else:
            try:
                # Context manager so the file handle is always closed.
                with open(f"characters/instruction-following/{template_name}.yaml", 'r', encoding='utf-8') as f:
                    instruct = yaml.safe_load(f)
                instruction_template, template_stops = _build_instruct_template(instruct)
                stopping_strings.extend(template_stops)
            except Exception as e:
                # Deliberate best-effort: any problem with the template file
                # falls back to the default prompt, including its stop string
                # so generation halts at the next '###' section.
                instruction_template = default_template
                stopping_strings.extend(['\n###'])
                print(f"Exception: When loading characters/instruction-following/{template_name}.yaml: {repr(e)}")
                print("Warning: Loaded default instruction-following template (Alpaca) for model.")
    else:
        stopping_strings.extend(['\n###'])
        print("Warning: Loaded default instruction-following template (Alpaca) for model.")

    edit_task = instruction_template.format(instruction=instruction, input=input)

    truncation_length = shared.settings['truncation_length']

    # Whatever the prompt does not consume is the budget for the completion.
    token_count = len(encode(edit_task)[0])
    max_tokens = truncation_length - token_count

    if max_tokens < 1:
        err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
        raise InvalidRequestError(err_msg, param='input')

    req_params['max_new_tokens'] = max_tokens
    req_params['truncation_length'] = truncation_length
    req_params['temperature'] = temperature
    req_params['top_p'] = top_p
    req_params['seed'] = shared.settings.get('seed', req_params['seed'])
    req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
    req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']

    debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})

    generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)

    # Only the last value yielded by the generator is kept as the answer.
    answer = ''
    for a in generator:
        answer = a

    # some replies have an extra leading space to fit the instruction template, just clip it off from the reply.
    if edit_task[-1] != '\n' and answer and answer[0] == ' ':
        answer = answer[1:]

    completion_token_count = len(encode(answer)[0])

    resp = {
        "object": "edit",
        "created": created_time,
        "choices": [{
            "text": answer,
            "index": 0,
        }],
        "usage": {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }
    }

    return resp