mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Refactor text_generation.py, add support for custom generation functions (#1817)
This commit is contained in:
parent
876fbb97c0
commit
8aafb1f796
@ -45,7 +45,9 @@ Most of these have been created by the extremely talented contributors that you
|
||||
| `def ui()` | Creates custom gradio elements when the UI is launched. |
|
||||
| `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
|
||||
| `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
|
||||
| `def state_modifier(state)` | Modifies the dictionary containing the input parameters before it is used by the text generation functions. |
|
||||
| `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply (more on that below). |
|
||||
| `def custom_generate_reply(...)` | Overrides the main text generation function. |
|
||||
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
|
||||
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See `llava` extension for an example |
|
||||
|
||||
@ -104,6 +106,23 @@ python server.py --extensions enthusiasm translate # First apply enthusiasm, the
|
||||
python server.py --extensions translate enthusiasm # First apply translate, then enthusiasm
|
||||
```
|
||||
|
||||
## `custom_generate_reply` example
|
||||
|
||||
Once defined in a `script.py`, this function is executed in place of the main generation functions. You can use it to connect the web UI to an external API, or to load a custom model that is not supported yet.
|
||||
|
||||
```python
|
||||
import datetime
|
||||
|
||||
def custom_generate_reply(question, original_question, seed, state, eos_token, stopping_strings):
|
||||
cumulative = ''
|
||||
for i in range(10):
|
||||
cumulative += f"Counting: {i}...\n"
|
||||
yield cumulative
|
||||
|
||||
cumulative += f"Done! {str(datetime.datetime.now())}"
|
||||
yield cumulative
|
||||
```
|
||||
|
||||
## `custom_generate_chat_prompt` example
|
||||
|
||||
Below is an extension that just reproduces the default prompt generator in `modules/chat.py`. You can modify it freely to come up with your own prompts in chat mode.
|
||||
@ -114,51 +133,64 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
is_instruct = state['mode'] == 'instruct'
|
||||
rows = [f"{state['context'].strip()}\n"]
|
||||
rows = [state['context'] if is_instruct else f"{state['context'].strip()}\n"]
|
||||
min_rows = 3
|
||||
|
||||
# Finding the maximum prompt size
|
||||
chat_prompt_size = state['chat_prompt_size']
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
|
||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||
|
||||
# Building the turn templates
|
||||
if 'turn_template' not in state or state['turn_template'] == '':
|
||||
if is_instruct:
|
||||
prefix1 = f"{state['name1']}\n"
|
||||
prefix2 = f"{state['name2']}\n"
|
||||
template = '<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n'
|
||||
else:
|
||||
prefix1 = f"{state['name1']}: "
|
||||
prefix2 = f"{state['name2']}: "
|
||||
template = '<|user|>: <|user-message|>\n<|bot|>: <|bot-message|>\n'
|
||||
else:
|
||||
template = state['turn_template'].replace(r'\n', '\n')
|
||||
|
||||
replacements = {
|
||||
'<|user|>': state['name1'].strip(),
|
||||
'<|bot|>': state['name2'].strip(),
|
||||
}
|
||||
|
||||
user_turn = replace_all(template.split('<|bot|>')[0], replacements)
|
||||
bot_turn = replace_all('<|bot|>' + template.split('<|bot|>')[1], replacements)
|
||||
user_turn_stripped = replace_all(user_turn.split('<|user-message|>')[0], replacements)
|
||||
bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
|
||||
|
||||
# Building the prompt
|
||||
i = len(shared.history['internal']) - 1
|
||||
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
||||
if _continue and i == len(shared.history['internal']) - 1:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||
rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
|
||||
else:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
|
||||
rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip()))
|
||||
|
||||
string = shared.history['internal'][i][0]
|
||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
|
||||
rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
|
||||
|
||||
i -= 1
|
||||
|
||||
if impersonate:
|
||||
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
||||
limit = 2
|
||||
elif _continue:
|
||||
limit = 3
|
||||
else:
|
||||
min_rows = 2
|
||||
rows.append(user_turn_stripped.rstrip(' '))
|
||||
elif not _continue:
|
||||
# Adding the user message
|
||||
user_input = fix_newlines(user_input)
|
||||
if len(user_input) > 0:
|
||||
rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
|
||||
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))}))
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||
limit = 3
|
||||
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
|
||||
|
||||
while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length:
|
||||
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
|
||||
rows.pop(1)
|
||||
prompt = ''.join(rows)
|
||||
|
||||
prompt = ''.join(rows)
|
||||
if also_return_rows:
|
||||
return prompt, rows
|
||||
else:
|
||||
|
@ -33,6 +33,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
prompt = body['prompt']
|
||||
generate_params = build_parameters(body)
|
||||
stopping_strings = generate_params.pop('stopping_strings')
|
||||
generate_params['stream'] = False
|
||||
|
||||
generator = generate_reply(
|
||||
prompt, generate_params, stopping_strings=stopping_strings)
|
||||
|
@ -23,6 +23,7 @@ async def _handle_connection(websocket, path):
|
||||
prompt = message['prompt']
|
||||
generate_params = build_parameters(message)
|
||||
stopping_strings = generate_params.pop('stopping_strings')
|
||||
generate_params['stream'] = True
|
||||
|
||||
generator = generate_reply(
|
||||
prompt, generate_params, stopping_strings=stopping_strings)
|
||||
|
@ -18,15 +18,8 @@ wav_idx = 0
|
||||
user = ElevenLabsUser(params['api_key'])
|
||||
user_info = None
|
||||
|
||||
if not shared.args.no_stream:
|
||||
print("Please add --no-stream. This extension is not meant to be used with streaming.")
|
||||
raise ValueError
|
||||
|
||||
# Check if the API is valid and refresh the UI accordingly.
|
||||
|
||||
|
||||
def check_valid_api():
|
||||
|
||||
global user, user_info, params
|
||||
|
||||
user = ElevenLabsUser(params['api_key'])
|
||||
@ -41,9 +34,8 @@ def check_valid_api():
|
||||
print('Got an API Key!')
|
||||
return gr.update(value='Connected')
|
||||
|
||||
|
||||
# Once the API is verified, get the available voices and update the dropdown list
|
||||
|
||||
|
||||
def refresh_voices():
|
||||
|
||||
global user, user_info
|
||||
@ -63,6 +55,11 @@ def remove_surrounded_chars(string):
|
||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||
|
||||
|
||||
def state_modifier(state):
|
||||
state['stream'] = False
|
||||
return state
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
@ -109,6 +106,7 @@ def ui():
|
||||
with gr.Row():
|
||||
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
|
||||
connection_status = gr.Textbox(value='Disconnected', label='Connection Status')
|
||||
|
||||
voice = gr.Dropdown(value=params['selected_voice'], choices=initial_voice, label='TTS Voice')
|
||||
with gr.Row():
|
||||
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
|
||||
|
@ -266,8 +266,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
stopping_strings += standard_stopping_strings
|
||||
req_params['custom_stopping_strings'] = stopping_strings
|
||||
|
||||
shared.args.no_stream = not req_params['stream']
|
||||
if not shared.args.no_stream:
|
||||
if req_params['stream']:
|
||||
shared.args.chat = True
|
||||
# begin streaming
|
||||
chunk = {
|
||||
@ -337,7 +336,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
if buffer_and_continue:
|
||||
continue
|
||||
|
||||
if not shared.args.no_stream:
|
||||
if req_params['stream']:
|
||||
# Streaming
|
||||
new_content = answer[len_seen:]
|
||||
|
||||
@ -365,7 +364,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
completion_token_count += len(encode(new_content)[0])
|
||||
|
||||
if not shared.args.no_stream:
|
||||
if req_params['stream']:
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": stream_object_type,
|
||||
|
@ -75,7 +75,6 @@ if params['manage_VRAM']:
|
||||
samplers = ['DDIM', 'DPM++ 2M Karras'] # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers
|
||||
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
|
||||
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
picture_response = False # specifies if the next model response should appear as a picture
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
@ -92,6 +91,13 @@ def triggers_are_in(string):
|
||||
return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string))
|
||||
|
||||
|
||||
def state_modifier(state):
|
||||
if picture_response:
|
||||
state['stream'] = False
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
@ -218,14 +224,13 @@ def bot_prefix_modifier(string):
|
||||
|
||||
|
||||
def toggle_generation(*args):
|
||||
global picture_response, shared, streaming_state
|
||||
global picture_response, shared
|
||||
|
||||
if not args:
|
||||
picture_response = not picture_response
|
||||
else:
|
||||
picture_response = args[0]
|
||||
|
||||
shared.args.no_stream = True if picture_response else streaming_state # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
||||
shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"
|
||||
|
||||
|
||||
|
@ -43,5 +43,5 @@ def ui():
|
||||
picture_select.upload(
|
||||
lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
|
||||
lambda: None, None, picture_select, show_progress=False)
|
||||
|
@ -29,7 +29,6 @@ current_params = params.copy()
|
||||
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
|
||||
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
|
||||
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
|
||||
# Used for making text xml compatible, needed for voice pitch and speed control
|
||||
table = str.maketrans({
|
||||
@ -76,6 +75,11 @@ def toggle_text_in_history(name1, name2, mode):
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def state_modifier(state):
|
||||
state['stream'] = False
|
||||
return state
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
@ -87,7 +91,6 @@ def input_modifier(string):
|
||||
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')]
|
||||
|
||||
shared.processing_message = "*Is recording a voice message...*"
|
||||
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
|
||||
return string
|
||||
|
||||
|
||||
@ -124,7 +127,6 @@ def output_modifier(string):
|
||||
string += f'\n\n{original_string}'
|
||||
|
||||
shared.processing_message = "*Is typing...*"
|
||||
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
|
||||
return string
|
||||
|
||||
|
||||
|
@ -86,6 +86,15 @@ def _apply_custom_generate_chat_prompt(text, state, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
# Extension that modifies the input parameters before they are used
|
||||
def _apply_state_modifier_extensions(state):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, "state_modifier"):
|
||||
state = getattr(extension, "state_modifier")(state)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
# Extension functions that override the default tokenizer output
|
||||
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
|
||||
for extension, _ in iterator():
|
||||
@ -95,13 +104,24 @@ def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_e
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
|
||||
# Custom generate reply handling
|
||||
def _apply_custom_generate_reply():
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, 'custom_generate_reply'):
|
||||
return getattr(extension, 'custom_generate_reply')
|
||||
|
||||
return None
|
||||
|
||||
|
||||
EXTENSION_MAP = {
|
||||
"input": partial(_apply_string_extensions, "input_modifier"),
|
||||
"output": partial(_apply_string_extensions, "output_modifier"),
|
||||
"state": _apply_state_modifier_extensions,
|
||||
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
|
||||
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
||||
"input_hijack": _apply_input_hijack,
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
|
||||
"custom_generate_reply": _apply_custom_generate_reply
|
||||
}
|
||||
|
||||
|
||||
|
@ -21,6 +21,7 @@ def get_max_prompt_length(state):
|
||||
max_length = state['truncation_length'] - state['max_new_tokens']
|
||||
if shared.soft_prompt:
|
||||
max_length -= shared.soft_prompt_tensor.shape[1]
|
||||
|
||||
return max_length
|
||||
|
||||
|
||||
@ -62,6 +63,36 @@ def decode(output_ids, skip_special_tokens=True):
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||
|
||||
|
||||
def generate_softprompt_input_tensors(input_ids):
|
||||
inputs_embeds = shared.model.transformer.wte(input_ids)
|
||||
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
|
||||
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
|
||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
return inputs_embeds, filler_input_ids
|
||||
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
def fix_gpt4chan(s):
|
||||
for i in range(10):
|
||||
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
||||
s = re.sub("--- [0-9]*\n *\n---", "---", s)
|
||||
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
# Fix the LaTeX equations in galactica
|
||||
def fix_galactica(s):
|
||||
s = s.replace(r'\[', r'$')
|
||||
s = s.replace(r'\]', r'$')
|
||||
s = s.replace(r'\(', r'$')
|
||||
s = s.replace(r'\)', r'$')
|
||||
s = s.replace(r'$$', r'$')
|
||||
s = re.sub(r'\n', r'\n\n', s)
|
||||
s = re.sub(r"\n{3,}", "\n\n", s)
|
||||
return s
|
||||
|
||||
|
||||
def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
|
||||
if shared.model_type == 'HF_seq2seq':
|
||||
reply = decode(output_ids, state['skip_special_tokens'])
|
||||
@ -81,35 +112,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
|
||||
return reply
|
||||
|
||||
|
||||
def generate_softprompt_input_tensors(input_ids):
|
||||
inputs_embeds = shared.model.transformer.wte(input_ids)
|
||||
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
|
||||
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
|
||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
return inputs_embeds, filler_input_ids
|
||||
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
def fix_gpt4chan(s):
|
||||
for i in range(10):
|
||||
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
||||
s = re.sub("--- [0-9]*\n *\n---", "---", s)
|
||||
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
||||
return s
|
||||
|
||||
|
||||
# Fix the LaTeX equations in galactica
|
||||
def fix_galactica(s):
|
||||
s = s.replace(r'\[', r'$')
|
||||
s = s.replace(r'\]', r'$')
|
||||
s = s.replace(r'\(', r'$')
|
||||
s = s.replace(r'\)', r'$')
|
||||
s = s.replace(r'$$', r'$')
|
||||
s = re.sub(r'\n', r'\n\n', s)
|
||||
s = re.sub(r"\n{3,}", "\n\n", s)
|
||||
return s
|
||||
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if not shared.is_chat():
|
||||
if shared.model_type == 'galactica':
|
||||
@ -140,25 +142,39 @@ def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
|
||||
def get_generate_params(state):
|
||||
generate_params = {}
|
||||
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
state = apply_extensions('state', state)
|
||||
generate_func = apply_extensions('custom_generate_reply')
|
||||
if generate_func is None:
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
logging.error("No model is loaded! Select one in the Model tab.")
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
return
|
||||
|
||||
# Models that are not on transformers
|
||||
if shared.model_type in ['rwkv', 'llamacpp']:
|
||||
generate_params['token_count'] = state['max_new_tokens']
|
||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||
generate_params[k] = state[k]
|
||||
generate_func = generate_reply_custom
|
||||
elif shared.args.flexgen:
|
||||
generate_func = generate_reply_flexgen
|
||||
else:
|
||||
# FlexGen
|
||||
if shared.args.flexgen:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||
generate_params[k] = state[k]
|
||||
generate_func = generate_reply_HF
|
||||
|
||||
if not shared.args.no_stream:
|
||||
generate_params['max_new_tokens'] = 8
|
||||
# Preparing the input
|
||||
original_question = question
|
||||
if not shared.is_chat():
|
||||
question = apply_extensions('input', question)
|
||||
|
||||
# transformers
|
||||
else:
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{question}\n--------------------\n')
|
||||
|
||||
shared.stop_everything = False
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
for reply in generate_func(question, original_question, seed, state, eos_token, stopping_strings):
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
|
||||
def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
@ -171,65 +187,10 @@ def get_generate_params(state):
|
||||
if shared.args.deepspeed:
|
||||
generate_params.update({'synced_gpus': True})
|
||||
|
||||
return generate_params
|
||||
|
||||
|
||||
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
logging.error("No model is loaded! Select one in the Model tab.")
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
return
|
||||
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
shared.stop_everything = False
|
||||
generate_params = get_generate_params(state)
|
||||
t0 = time.time()
|
||||
|
||||
# Preparing the input
|
||||
original_question = question
|
||||
if not shared.is_chat():
|
||||
question = apply_extensions('input', question)
|
||||
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{question}\n--------------------\n')
|
||||
|
||||
# If the model is not on transformers, handle it separately and end this
|
||||
# function call earlier.
|
||||
if shared.model_type in ['rwkv', 'llamacpp']:
|
||||
|
||||
try:
|
||||
if shared.args.no_stream:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
else:
|
||||
if not shared.is_chat():
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(encode(original_question)[0])
|
||||
new_tokens = len(encode(output)[0]) - original_tokens
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
# Encode the input
|
||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
output = input_ids[0]
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed))
|
||||
|
||||
# Find the eos tokens
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
@ -259,15 +220,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
break
|
||||
|
||||
# Update generate_params with the eos token and the stopping strings
|
||||
if shared.args.flexgen:
|
||||
generate_params['stop'] = eos_token_ids[-1]
|
||||
else:
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
|
||||
yield original_question
|
||||
|
||||
# Generate the entire reply at once.
|
||||
if shared.args.no_stream:
|
||||
if not state['stream']:
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
if cuda:
|
||||
@ -276,12 +238,11 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
yield get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
|
||||
# Stream the reply 1 token at a time.
|
||||
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
||||
elif not shared.args.flexgen:
|
||||
else:
|
||||
|
||||
def generate_with_callback(callback=None, **kwargs):
|
||||
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
||||
@ -292,45 +253,118 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
def generate_with_streaming(**kwargs):
|
||||
return Iteratorize(generate_with_callback, kwargs, callback=None)
|
||||
|
||||
if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
|
||||
yield formatted_outputs(original_question, shared.model_name)
|
||||
|
||||
with generate_with_streaming(**generate_params) as generator:
|
||||
for output in generator:
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
yield get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
for i in range(state['max_new_tokens'] // 8 + 1):
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
generate_params.update({'inputs': filler_input_ids})
|
||||
else:
|
||||
generate_params.update({'inputs': input_ids})
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(original_input_ids[0])
|
||||
new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
|
||||
def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
|
||||
seed = set_manual_seed(state['seed'])
|
||||
generate_params = {'token_count': state['max_new_tokens']}
|
||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
if not shared.is_chat():
|
||||
yield question
|
||||
|
||||
if not state['stream']:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield reply
|
||||
else:
|
||||
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield reply
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(encode(original_question)[0])
|
||||
new_tokens = len(encode(output)[0]) - original_tokens
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
|
||||
def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
if state['stream']:
|
||||
generate_params['max_new_tokens'] = 8
|
||||
|
||||
# Encode the input
|
||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
output = input_ids[0]
|
||||
|
||||
# Find the eos tokens
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
if eos_token is not None:
|
||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||
|
||||
# Add the encoded tokens to generate_params
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs': input_ids})
|
||||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
|
||||
# Update generate_params with the eos token and the stopping strings
|
||||
generate_params['stop'] = eos_token_ids[-1]
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
if not shared.is_chat():
|
||||
yield question
|
||||
|
||||
# Generate the entire reply at once.
|
||||
if not state['stream']:
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
||||
yield get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
for i in range(state['max_new_tokens'] // 8 + 1):
|
||||
if shared.stop_everything:
|
||||
break
|
||||
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
|
||||
yield get_reply_from_output_ids(output, original_input_ids, original_question, state)
|
||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||
generate_params.update({'inputs': input_ids})
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
|
@ -34,7 +34,7 @@ def list_model_elements():
|
||||
|
||||
|
||||
def list_interface_input_elements(chat=False):
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu']
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream']
|
||||
if chat:
|
||||
elements += ['name1', 'name2', 'greeting', 'context', 'turn_template', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']
|
||||
|
||||
|
28
server.py
28
server.py
@ -15,6 +15,7 @@ def my_get(url, **kwargs):
|
||||
kwargs.setdefault('allow_redirects', True)
|
||||
return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
|
||||
|
||||
|
||||
original_get = requests.get
|
||||
requests.get = my_get
|
||||
import gradio as gr
|
||||
@ -454,6 +455,7 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
|
||||
|
||||
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
|
||||
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
|
||||
|
||||
with gr.Accordion('Soft prompt', open=False):
|
||||
with gr.Row():
|
||||
@ -721,46 +723,46 @@ def create_interface():
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Regenerate'].click(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Continue'].click(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Impersonate'].click(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False)
|
||||
)
|
||||
|
||||
shared.gradio['Replace last reply'].click(
|
||||
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then(
|
||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
|
||||
shared.gradio['Send dummy message'].click(
|
||||
chat.send_dummy_message, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.send_dummy_message, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then(
|
||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
|
||||
shared.gradio['Send dummy reply'].click(
|
||||
chat.send_dummy_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.send_dummy_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then(
|
||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
|
||||
@ -786,7 +788,7 @@ def create_interface():
|
||||
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
|
||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=False)
|
||||
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||
@ -808,14 +810,14 @@ def create_interface():
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=False) # .then(
|
||||
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=False) # .then(
|
||||
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
@ -824,13 +826,13 @@ def create_interface():
|
||||
gen_events.append(shared.gradio['Regenerate'].click(
|
||||
lambda x: x, shared.gradio['last_input'], shared.gradio['textbox'], show_progress=False).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=False) # .then(
|
||||
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
else:
|
||||
gen_events.append(shared.gradio['Continue'].click(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream) # .then(
|
||||
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=False) # .then(
|
||||
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user