mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
103 lines
4.1 KiB
Python
103 lines
4.1 KiB
Python
import time
|
|
import yaml
|
|
import os
|
|
from modules import shared
|
|
from extensions.openai.defaults import get_default_req_params
|
|
from extensions.openai.utils import debug_msg
|
|
from extensions.openai.errors import *
|
|
from modules.text_generation import encode, generate_reply
|
|
|
|
|
|
def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
|
|
|
|
created_time = int(time.time() * 1000)
|
|
|
|
# Request parameters
|
|
req_params = get_default_req_params()
|
|
stopping_strings = []
|
|
|
|
# Alpaca is verbose so a good default prompt
|
|
default_template = (
|
|
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
|
"Write a response that appropriately completes the request.\n\n"
|
|
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
|
)
|
|
|
|
instruction_template = default_template
|
|
|
|
# Use the special instruction/input/response template for anything trained like Alpaca
|
|
if shared.settings['instruction_template']:
|
|
if 'Alpaca' in shared.settings['instruction_template']:
|
|
stopping_strings.extend(['\n###'])
|
|
else:
|
|
try:
|
|
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
|
|
|
template = instruct['turn_template']
|
|
template = template\
|
|
.replace('<|user|>', instruct.get('user', ''))\
|
|
.replace('<|bot|>', instruct.get('bot', ''))\
|
|
.replace('<|user-message|>', '{instruction}\n{input}')
|
|
|
|
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
|
if instruct['user']:
|
|
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
|
|
|
|
except Exception as e:
|
|
instruction_template = default_template
|
|
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
|
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
|
else:
|
|
stopping_strings.extend(['\n###'])
|
|
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
|
|
|
edit_task = instruction_template.format(instruction=instruction, input=input)
|
|
|
|
truncation_length = shared.settings['truncation_length']
|
|
|
|
token_count = len(encode(edit_task)[0])
|
|
max_tokens = truncation_length - token_count
|
|
|
|
if max_tokens < 1:
|
|
err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
|
|
raise InvalidRequestError(err_msg, param='input')
|
|
|
|
req_params['max_new_tokens'] = max_tokens
|
|
req_params['truncation_length'] = truncation_length
|
|
req_params['temperature'] = temperature
|
|
req_params['top_p'] = top_p
|
|
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
|
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
|
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
|
|
|
debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
|
|
|
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
|
|
|
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
|
answer = ''
|
|
for a in generator:
|
|
answer = a
|
|
|
|
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
|
|
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
|
|
answer = answer[1:]
|
|
|
|
completion_token_count = len(encode(answer)[0])
|
|
|
|
resp = {
|
|
"object": "edit",
|
|
"created": created_time,
|
|
"choices": [{
|
|
"text": answer,
|
|
"index": 0,
|
|
}],
|
|
"usage": {
|
|
"prompt_tokens": token_count,
|
|
"completion_tokens": completion_token_count,
|
|
"total_tokens": token_count + completion_token_count
|
|
}
|
|
}
|
|
|
|
return resp
|