Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-25 09:19:23 +01:00)

Refactor text_generation.py, add support for custom generation functions (#1817)

This commit is contained in:
commit 8aafb1f796
parent 876fbb97c0
@@ -45,7 +45,9 @@ Most of these have been created by the extremely talented contributors that you
 | `def ui()` | Creates custom gradio elements when the UI is launched. |
 | `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
 | `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
+| `def state_modifier(state)` | Modifies the dictionary containing the input parameters before it is used by the text generation functions. |
 | `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply (more on that below). |
+| `def custom_generate_reply(...)` | Overrides the main text generation function. |
 | `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
 | `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See `llava` extension for an example |
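The newly documented `state_modifier` hook is not illustrated elsewhere in this diff, so here is a minimal sketch of an extension `script.py` that uses it. The extension name and the specific keys touched are hypothetical; the only assumption taken from the table above is that the hook receives and returns the dictionary of input parameters.

```python
# script.py of a hypothetical extension (e.g. extensions/force_greedy/script.py).
# Sketch only: adjusts the generation parameters before text_generation uses them.

def state_modifier(state):
    # 'state' holds the input parameters gathered from the UI,
    # e.g. 'temperature', 'do_sample', 'max_new_tokens', 'stream'.
    state['do_sample'] = False   # force greedy decoding
    state['stream'] = True       # keep token streaming enabled
    return state
```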
@@ -104,6 +106,23 @@ python server.py --extensions enthusiasm translate # First apply enthusiasm, then translate
 python server.py --extensions translate enthusiasm # First apply translate, then enthusiasm
 ```
 
+## `custom_generate_reply` example
+
+Once defined in a `script.py`, this function is executed in place of the main generation functions. You can use it to connect the web UI to an external API, or to load a custom model that is not supported yet.
+
+```python
+import datetime
+
+def custom_generate_reply(question, original_question, seed, state, eos_token, stopping_strings):
+    cumulative = ''
+    for i in range(10):
+        cumulative += f"Counting: {i}...\n"
+        yield cumulative
+
+    cumulative += f"Done! {str(datetime.datetime.now())}"
+    yield cumulative
+```
+
 ## `custom_generate_chat_prompt` example
 
 Below is an extension that just reproduces the default prompt generator in `modules/chat.py`. You can modify it freely to come up with your own prompts in chat mode.
@@ -114,51 +133,64 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
     _continue = kwargs['_continue'] if '_continue' in kwargs else False
     also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
     is_instruct = state['mode'] == 'instruct'
-    rows = [f"{state['context'].strip()}\n"]
+    rows = [state['context'] if is_instruct else f"{state['context'].strip()}\n"]
+    min_rows = 3
 
     # Finding the maximum prompt size
     chat_prompt_size = state['chat_prompt_size']
     if shared.soft_prompt:
         chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
 
     max_length = min(get_max_prompt_length(state), chat_prompt_size)
 
-    if is_instruct:
-        prefix1 = f"{state['name1']}\n"
-        prefix2 = f"{state['name2']}\n"
+    # Building the turn templates
+    if 'turn_template' not in state or state['turn_template'] == '':
+        if is_instruct:
+            template = '<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n'
+        else:
+            template = '<|user|>: <|user-message|>\n<|bot|>: <|bot-message|>\n'
     else:
-        prefix1 = f"{state['name1']}: "
-        prefix2 = f"{state['name2']}: "
+        template = state['turn_template'].replace(r'\n', '\n')
+
+    replacements = {
+        '<|user|>': state['name1'].strip(),
+        '<|bot|>': state['name2'].strip(),
+    }
+
+    user_turn = replace_all(template.split('<|bot|>')[0], replacements)
+    bot_turn = replace_all('<|bot|>' + template.split('<|bot|>')[1], replacements)
+    user_turn_stripped = replace_all(user_turn.split('<|user-message|>')[0], replacements)
+    bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
 
+    # Building the prompt
     i = len(shared.history['internal']) - 1
     while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
         if _continue and i == len(shared.history['internal']) - 1:
-            rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
+            rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
         else:
-            rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
+            rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip()))
 
         string = shared.history['internal'][i][0]
         if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
-            rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
+            rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
 
         i -= 1
 
     if impersonate:
-        rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
-        limit = 2
-    elif _continue:
-        limit = 3
-    else:
+        min_rows = 2
+        rows.append(user_turn_stripped.rstrip(' '))
+    elif not _continue:
         # Adding the user message
-        user_input = fix_newlines(user_input)
         if len(user_input) > 0:
-            rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
+            rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))}))
 
         # Adding the Character prefix
-        rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
-        limit = 3
+        rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
 
-    while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length:
+    while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
         rows.pop(1)
-    prompt = ''.join(rows)
+
+    prompt = ''.join(rows)
     if also_return_rows:
         return prompt, rows
     else:
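The documentation example above leans on helpers imported from the web UI itself (`shared`, `encode`, `get_max_prompt_length`, `apply_extensions`, `replace_all`). If you only want to trace the turn-template logic outside the UI, a stand-in for `replace_all` that is consistent with how it is used above could look like this; this is an assumption inferred from its call sites, not the project's actual implementation.

```python
# Assumed stand-in for the replace_all helper used in the example above:
# apply every '<|placeholder|>' -> value substitution in the dict to the string.
def replace_all(text, replacements):
    for placeholder, value in replacements.items():
        text = text.replace(placeholder, value)
    return text
```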
@@ -33,6 +33,7 @@ class Handler(BaseHTTPRequestHandler):
             prompt = body['prompt']
             generate_params = build_parameters(body)
             stopping_strings = generate_params.pop('stopping_strings')
+            generate_params['stream'] = False
 
             generator = generate_reply(
                 prompt, generate_params, stopping_strings=stopping_strings)
@@ -66,7 +67,7 @@ class Handler(BaseHTTPRequestHandler):
             self.send_error(404)
 
 
-def _run_server(port: int, share: bool=False):
+def _run_server(port: int, share: bool = False):
     address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
 
     server = ThreadingHTTPServer((address, port), Handler)
@@ -23,6 +23,7 @@ async def _handle_connection(websocket, path):
         prompt = message['prompt']
         generate_params = build_parameters(message)
         stopping_strings = generate_params.pop('stopping_strings')
+        generate_params['stream'] = True
 
         generator = generate_reply(
             prompt, generate_params, stopping_strings=stopping_strings)
@@ -18,15 +18,8 @@ wav_idx = 0
 user = ElevenLabsUser(params['api_key'])
 user_info = None
 
-if not shared.args.no_stream:
-    print("Please add --no-stream. This extension is not meant to be used with streaming.")
-    raise ValueError
-
 
 # Check if the API is valid and refresh the UI accordingly.
 def check_valid_api():
     global user, user_info, params
     user = ElevenLabsUser(params['api_key'])
@@ -41,9 +34,8 @@ def check_valid_api():
         print('Got an API Key!')
         return gr.update(value='Connected')
 
 
 # Once the API is verified, get the available voices and update the dropdown list
 def refresh_voices():
     global user, user_info
@@ -63,6 +55,11 @@ def remove_surrounded_chars(string):
     return re.sub('\*[^\*]*?(\*|$)', '', string)
 
 
+def state_modifier(state):
+    state['stream'] = False
+    return state
+
+
 def input_modifier(string):
     """
     This function is applied to your text inputs before
@@ -109,6 +106,7 @@ def ui():
     with gr.Row():
         activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
         connection_status = gr.Textbox(value='Disconnected', label='Connection Status')
         voice = gr.Dropdown(value=params['selected_voice'], choices=initial_voice, label='TTS Voice')
     with gr.Row():
         api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
@@ -266,8 +266,7 @@ class Handler(BaseHTTPRequestHandler):
             stopping_strings += standard_stopping_strings
             req_params['custom_stopping_strings'] = stopping_strings
 
-            shared.args.no_stream = not req_params['stream']
-            if not shared.args.no_stream:
+            if req_params['stream']:
                 shared.args.chat = True
                 # begin streaming
                 chunk = {
@@ -337,7 +336,7 @@ class Handler(BaseHTTPRequestHandler):
                 if buffer_and_continue:
                     continue
 
-                if not shared.args.no_stream:
+                if req_params['stream']:
                     # Streaming
                     new_content = answer[len_seen:]
 
@@ -365,7 +364,7 @@ class Handler(BaseHTTPRequestHandler):
                     self.wfile.write(response.encode('utf-8'))
                     completion_token_count += len(encode(new_content)[0])
 
-            if not shared.args.no_stream:
+            if req_params['stream']:
                 chunk = {
                     "id": cmpl_id,
                     "object": stream_object_type,
@@ -75,7 +75,6 @@ if params['manage_VRAM']:
 samplers = ['DDIM', 'DPM++ 2M Karras'] # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers
 SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
 
-streaming_state = shared.args.no_stream # remember if chat streaming was enabled
 picture_response = False # specifies if the next model response should appear as a picture
 
 def remove_surrounded_chars(string):
@@ -92,6 +91,13 @@ def triggers_are_in(string):
     return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string))
 
 
+def state_modifier(state):
+    if picture_response:
+        state['stream'] = False
+
+    return state
+
+
 def input_modifier(string):
     """
     This function is applied to your text inputs before
@@ -218,14 +224,13 @@ def bot_prefix_modifier(string):
 
 
 def toggle_generation(*args):
-    global picture_response, shared, streaming_state
+    global picture_response, shared
 
     if not args:
         picture_response = not picture_response
     else:
        picture_response = args[0]
 
-    shared.args.no_stream = True if picture_response else streaming_state # Disable streaming cause otherwise the SD-generated picture would return as a dud
     shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"
 
@@ -43,5 +43,5 @@ def ui():
     picture_select.upload(
         lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then(
         gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
-        chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
+        chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
         lambda: None, None, picture_select, show_progress=False)
@@ -29,7 +29,6 @@ current_params = params.copy()
 voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
 voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
 voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
-streaming_state = shared.args.no_stream # remember if chat streaming was enabled
 
 # Used for making text xml compatible, needed for voice pitch and speed control
 table = str.maketrans({
@@ -76,6 +75,11 @@ def toggle_text_in_history(name1, name2, mode):
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
 
+def state_modifier(state):
+    state['stream'] = False
+    return state
+
+
 def input_modifier(string):
     """
     This function is applied to your text inputs before
@@ -87,7 +91,6 @@ def input_modifier(string):
     shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')]
 
     shared.processing_message = "*Is recording a voice message...*"
-    shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
     return string
 
 
@@ -124,7 +127,6 @@ def output_modifier(string):
     string += f'\n\n{original_string}'
 
     shared.processing_message = "*Is typing...*"
-    shared.args.no_stream = streaming_state # restore the streaming option to the previous value
     return string
 
 
@@ -86,6 +86,15 @@ def _apply_custom_generate_chat_prompt(text, state, **kwargs):
     return None
 
 
+# Extension that modifies the input parameters before they are used
+def _apply_state_modifier_extensions(state):
+    for extension, _ in iterator():
+        if hasattr(extension, "state_modifier"):
+            state = getattr(extension, "state_modifier")(state)
+
+    return state
+
+
 # Extension functions that override the default tokenizer output
 def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
     for extension, _ in iterator():
@@ -95,13 +104,24 @@ def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
     return prompt, input_ids, input_embeds
 
 
+# Custom generate reply handling
+def _apply_custom_generate_reply():
+    for extension, _ in iterator():
+        if hasattr(extension, 'custom_generate_reply'):
+            return getattr(extension, 'custom_generate_reply')
+
+    return None
+
+
 EXTENSION_MAP = {
     "input": partial(_apply_string_extensions, "input_modifier"),
     "output": partial(_apply_string_extensions, "output_modifier"),
+    "state": _apply_state_modifier_extensions,
     "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
     "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
     "input_hijack": _apply_input_hijack,
-    "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
+    "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
+    "custom_generate_reply": _apply_custom_generate_reply
 }
 
 
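The two helpers above are reached through `apply_extensions`, which looks up these `EXTENSION_MAP` keys. As a condensed sketch (not new code in the module, just a summary that mirrors the calls visible in the refactored `generate_reply` further down in this diff), the new entries get exercised roughly like this:

```python
# Condensed sketch of the new dispatch path, mirroring the text_generation.py hunks below.
state = apply_extensions('state', state)                    # run every extension's state_modifier
generate_func = apply_extensions('custom_generate_reply')   # first extension defining it, or None
if generate_func is None:
    generate_func = generate_reply_HF                       # or generate_reply_custom / generate_reply_flexgen

for reply in generate_func(question, original_question, seed, state, eos_token, stopping_strings):
    yield formatted_outputs(reply, shared.model_name)
```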
@@ -21,6 +21,7 @@ def get_max_prompt_length(state):
     max_length = state['truncation_length'] - state['max_new_tokens']
     if shared.soft_prompt:
         max_length -= shared.soft_prompt_tensor.shape[1]
 
     return max_length
 
 
@@ -62,6 +63,36 @@ def decode(output_ids, skip_special_tokens=True):
     return shared.tokenizer.decode(output_ids, skip_special_tokens)
 
 
+def generate_softprompt_input_tensors(input_ids):
+    inputs_embeds = shared.model.transformer.wte(input_ids)
+    inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
+    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
+    # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
+    return inputs_embeds, filler_input_ids
+
+
+# Removes empty replies from gpt4chan outputs
+def fix_gpt4chan(s):
+    for i in range(10):
+        s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
+        s = re.sub("--- [0-9]*\n *\n---", "---", s)
+        s = re.sub("--- [0-9]*\n\n\n---", "---", s)
+
+    return s
+
+
+# Fix the LaTeX equations in galactica
+def fix_galactica(s):
+    s = s.replace(r'\[', r'$')
+    s = s.replace(r'\]', r'$')
+    s = s.replace(r'\(', r'$')
+    s = s.replace(r'\)', r'$')
+    s = s.replace(r'$$', r'$')
+    s = re.sub(r'\n', r'\n\n', s)
+    s = re.sub(r"\n{3,}", "\n\n", s)
+    return s
+
+
 def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
     if shared.model_type == 'HF_seq2seq':
         reply = decode(output_ids, state['skip_special_tokens'])
@@ -81,35 +112,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
     return reply
 
 
-def generate_softprompt_input_tensors(input_ids):
-    inputs_embeds = shared.model.transformer.wte(input_ids)
-    inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
-    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
-    # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
-    return inputs_embeds, filler_input_ids
-
-
-# Removes empty replies from gpt4chan outputs
-def fix_gpt4chan(s):
-    for i in range(10):
-        s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
-        s = re.sub("--- [0-9]*\n *\n---", "---", s)
-        s = re.sub("--- [0-9]*\n\n\n---", "---", s)
-
-    return s
-
-
-# Fix the LaTeX equations in galactica
-def fix_galactica(s):
-    s = s.replace(r'\[', r'$')
-    s = s.replace(r'\]', r'$')
-    s = s.replace(r'\(', r'$')
-    s = s.replace(r'\)', r'$')
-    s = s.replace(r'$$', r'$')
-    s = re.sub(r'\n', r'\n\n', s)
-    s = re.sub(r"\n{3,}", "\n\n", s)
-    return s
-
-
 def formatted_outputs(reply, model_name):
     if not shared.is_chat():
         if shared.model_type == 'galactica':
@@ -140,51 +142,21 @@ def stop_everything_event():
     shared.stop_everything = True
 
 
-def get_generate_params(state):
-    generate_params = {}
-
-    # Models that are not on transformers
-    if shared.model_type in ['rwkv', 'llamacpp']:
-        generate_params['token_count'] = state['max_new_tokens']
-        for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
-            generate_params[k] = state[k]
-    else:
-        # FlexGen
-        if shared.args.flexgen:
-            for k in ['max_new_tokens', 'do_sample', 'temperature']:
-                generate_params[k] = state[k]
-
-            if not shared.args.no_stream:
-                generate_params['max_new_tokens'] = 8
-
-        # transformers
-        else:
-            for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
-                generate_params[k] = state[k]
-
-            if state['ban_eos_token']:
-                generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
-
-            if shared.args.no_cache:
-                generate_params.update({'use_cache': False})
-
-            if shared.args.deepspeed:
-                generate_params.update({'synced_gpus': True})
-
-    return generate_params
-
-
 def generate_reply(question, state, eos_token=None, stopping_strings=[]):
-    if shared.model_name == 'None' or shared.model is None:
-        logging.error("No model is loaded! Select one in the Model tab.")
-        yield formatted_outputs(question, shared.model_name)
-        return
-
-    clear_torch_cache()
-    seed = set_manual_seed(state['seed'])
-    shared.stop_everything = False
-    generate_params = get_generate_params(state)
-    t0 = time.time()
+    state = apply_extensions('state', state)
+    generate_func = apply_extensions('custom_generate_reply')
+    if generate_func is None:
+        if shared.model_name == 'None' or shared.model is None:
+            logging.error("No model is loaded! Select one in the Model tab.")
+            yield formatted_outputs(question, shared.model_name)
+            return
+
+        if shared.model_type in ['rwkv', 'llamacpp']:
+            generate_func = generate_reply_custom
+        elif shared.args.flexgen:
+            generate_func = generate_reply_flexgen
+        else:
+            generate_func = generate_reply_HF
 
     # Preparing the input
     original_question = question
@@ -194,42 +166,31 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
     if shared.args.verbose:
         print(f'\n\n{question}\n--------------------\n')
 
-    # If the model is not on transformers, handle it separately and end this
-    # function call earlier.
-    if shared.model_type in ['rwkv', 'llamacpp']:
+    shared.stop_everything = False
+    clear_torch_cache()
+    seed = set_manual_seed(state['seed'])
+    for reply in generate_func(question, original_question, seed, state, eos_token, stopping_strings):
+        yield formatted_outputs(reply, shared.model_name)
 
-        try:
-            if shared.args.no_stream:
-                reply = shared.model.generate(context=question, **generate_params)
-                output = original_question + reply
-                if not shared.is_chat():
-                    reply = original_question + apply_extensions('output', reply)
 
-                yield formatted_outputs(reply, shared.model_name)
-            else:
-                if not shared.is_chat():
-                    yield formatted_outputs(question, shared.model_name)
-
-                for reply in shared.model.generate_with_streaming(context=question, **generate_params):
-                    output = original_question + reply
-                    if not shared.is_chat():
-                        reply = original_question + apply_extensions('output', reply)
-
-                    yield formatted_outputs(reply, shared.model_name)
-
-        except Exception:
-            traceback.print_exc()
-        finally:
-            t1 = time.time()
-            original_tokens = len(encode(original_question)[0])
-            new_tokens = len(encode(output)[0]) - original_tokens
-            print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
-            return
-
+def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
+    generate_params = {}
+    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
+        generate_params[k] = state[k]
+
+    if state['ban_eos_token']:
+        generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
+
+    if shared.args.no_cache:
+        generate_params.update({'use_cache': False})
+
+    if shared.args.deepspeed:
+        generate_params.update({'synced_gpus': True})
+
     # Encode the input
     input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
     output = input_ids[0]
-    cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
+    cuda = not any((shared.args.cpu, shared.args.deepspeed))
 
     # Find the eos tokens
     eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
@@ -259,15 +220,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
             break
 
     # Update generate_params with the eos token and the stopping strings
-    if shared.args.flexgen:
-        generate_params['stop'] = eos_token_ids[-1]
-    else:
-        generate_params['eos_token_id'] = eos_token_ids
-        generate_params['stopping_criteria'] = stopping_criteria_list
+    generate_params['eos_token_id'] = eos_token_ids
+    generate_params['stopping_criteria'] = stopping_criteria_list
 
+    t0 = time.time()
     try:
+        if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
+            yield original_question
+
         # Generate the entire reply at once.
-        if shared.args.no_stream:
+        if not state['stream']:
             with torch.no_grad():
                 output = shared.model.generate(**generate_params)[0]
                 if cuda:
@@ -276,12 +238,11 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
                 if shared.soft_prompt:
                     output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
-                reply = get_reply_from_output_ids(output, input_ids, original_question, state)
-                yield formatted_outputs(reply, shared.model_name)
+                yield get_reply_from_output_ids(output, input_ids, original_question, state)
 
         # Stream the reply 1 token at a time.
         # This is based on the trick of using 'stopping_criteria' to create an iterator.
-        elif not shared.args.flexgen:
+        else:
 
             def generate_with_callback(callback=None, **kwargs):
                 kwargs['stopping_criteria'].append(Stream(callback_func=callback))
@@ -292,45 +253,118 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
             def generate_with_streaming(**kwargs):
                 return Iteratorize(generate_with_callback, kwargs, callback=None)
 
-            if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
-                yield formatted_outputs(original_question, shared.model_name)
-
             with generate_with_streaming(**generate_params) as generator:
                 for output in generator:
                     if shared.soft_prompt:
                         output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
-                    reply = get_reply_from_output_ids(output, input_ids, original_question, state)
+                    yield get_reply_from_output_ids(output, input_ids, original_question, state)
                     if output[-1] in eos_token_ids:
                         break
 
-                    yield formatted_outputs(reply, shared.model_name)
-
-        # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
-        else:
-            for i in range(state['max_new_tokens'] // 8 + 1):
-                clear_torch_cache()
-                with torch.no_grad():
-                    output = shared.model.generate(**generate_params)[0]
-
-                if shared.soft_prompt:
-                    output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
-                reply = get_reply_from_output_ids(output, input_ids, original_question, state)
-                if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
-                    break
-
-                yield formatted_outputs(reply, shared.model_name)
-                input_ids = np.reshape(output, (1, output.shape[0]))
-                if shared.soft_prompt:
-                    inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
-                    generate_params.update({'inputs_embeds': inputs_embeds})
-                    generate_params.update({'inputs': filler_input_ids})
-                else:
-                    generate_params.update({'inputs': input_ids})
-
-            yield formatted_outputs(reply, shared.model_name)
-
+    except Exception:
+        traceback.print_exc()
+    finally:
+        t1 = time.time()
+        original_tokens = len(original_input_ids[0])
+        new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
+        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+        return
+
+
+def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
+    seed = set_manual_seed(state['seed'])
+    generate_params = {'token_count': state['max_new_tokens']}
+    for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
+        generate_params[k] = state[k]
+
+    t0 = time.time()
+    try:
+        if not shared.is_chat():
+            yield question
+
+        if not state['stream']:
+            reply = shared.model.generate(context=question, **generate_params)
+            output = original_question + reply
+            if not shared.is_chat():
+                reply = original_question + apply_extensions('output', reply)
+
+            yield reply
+        else:
+            for reply in shared.model.generate_with_streaming(context=question, **generate_params):
+                output = original_question + reply
+                if not shared.is_chat():
+                    reply = original_question + apply_extensions('output', reply)
+
+                yield reply
+
+    except Exception:
+        traceback.print_exc()
+    finally:
+        t1 = time.time()
+        original_tokens = len(encode(original_question)[0])
+        new_tokens = len(encode(output)[0]) - original_tokens
+        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+        return
+
+
+def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
+    generate_params = {}
+    for k in ['max_new_tokens', 'do_sample', 'temperature']:
+        generate_params[k] = state[k]
+
+    if state['stream']:
+        generate_params['max_new_tokens'] = 8
+
+    # Encode the input
+    input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
+    output = input_ids[0]
+
+    # Find the eos tokens
+    eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
+    if eos_token is not None:
+        eos_token_ids.append(int(encode(eos_token)[0][-1]))
+
+    # Add the encoded tokens to generate_params
+    question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
+    original_input_ids = input_ids
+    generate_params.update({'inputs': input_ids})
+    if inputs_embeds is not None:
+        generate_params.update({'inputs_embeds': inputs_embeds})
+
+    # Update generate_params with the eos token and the stopping strings
+    generate_params['stop'] = eos_token_ids[-1]
+
+    t0 = time.time()
+    try:
+        if not shared.is_chat():
+            yield question
+
+        # Generate the entire reply at once.
+        if not state['stream']:
+            with torch.no_grad():
+                output = shared.model.generate(**generate_params)[0]
+
+            yield get_reply_from_output_ids(output, input_ids, original_question, state)
+
+        # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
+        else:
+            for i in range(state['max_new_tokens'] // 8 + 1):
+                if shared.stop_everything:
+                    break
+
+                clear_torch_cache()
+                with torch.no_grad():
+                    output = shared.model.generate(**generate_params)[0]
+
+                if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
+                    break
+
+                yield get_reply_from_output_ids(output, original_input_ids, original_question, state)
+                input_ids = np.reshape(output, (1, output.shape[0]))
+                generate_params.update({'inputs': input_ids})
+
     except Exception:
         traceback.print_exc()
     finally:
@@ -34,7 +34,7 @@ def list_model_elements():
 
 
 def list_interface_input_elements(chat=False):
-    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu']
+    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream']
     if chat:
         elements += ['name1', 'name2', 'greeting', 'context', 'turn_template', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']
 

server.py (28 changed lines)

@@ -15,6 +15,7 @@ def my_get(url, **kwargs):
     kwargs.setdefault('allow_redirects', True)
     return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
 
 
 original_get = requests.get
 requests.get = my_get
 import gradio as gr
@@ -454,6 +455,7 @@ def create_settings_menus(default_preset):
                     shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
                     shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
+                    shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
 
         with gr.Accordion('Soft prompt', open=False):
             with gr.Row():
@@ -721,46 +723,46 @@ def create_interface():
         gen_events.append(shared.gradio['Generate'].click(
             ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
             lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
-            chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
+            chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
             chat.save_history, shared.gradio['mode'], None, show_progress=False)
         )
 
         gen_events.append(shared.gradio['textbox'].submit(
             ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
             lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
-            chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
+            chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
            chat.save_history, shared.gradio['mode'], None, show_progress=False)
         )
 
         gen_events.append(shared.gradio['Regenerate'].click(
             ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
-            chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
+            chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
             chat.save_history, shared.gradio['mode'], None, show_progress=False)
         )
 
         gen_events.append(shared.gradio['Continue'].click(
             ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
-            chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
+            chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
             chat.save_history, shared.gradio['mode'], None, show_progress=False)
        )
 
         gen_events.append(shared.gradio['Impersonate'].click(
             ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
-            chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
+            chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False)
         )
 
         shared.gradio['Replace last reply'].click(
-            chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
+            chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then(
             lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
             chat.save_history, shared.gradio['mode'], None, show_progress=False)
 
         shared.gradio['Send dummy message'].click(
-            chat.send_dummy_message, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
+            chat.send_dummy_message, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then(
             lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
             chat.save_history, shared.gradio['mode'], None, show_progress=False)
 
         shared.gradio['Send dummy reply'].click(
-            chat.send_dummy_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
+            chat.send_dummy_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then(
             lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
             chat.save_history, shared.gradio['mode'], None, show_progress=False)
 
@@ -786,7 +788,7 @@ def create_interface():
             chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
             chat.redraw_html, reload_inputs, shared.gradio['display'])
 
-        shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
+        shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=False)
         shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
         shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
         shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
@@ -808,14 +810,14 @@ def create_interface():
         gen_events.append(shared.gradio['Generate'].click(
             lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then(
             ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
-            generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
+            generate_reply, shared.input_params, output_params, show_progress=False) # .then(
             # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
         )
 
         gen_events.append(shared.gradio['textbox'].submit(
             lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then(
             ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
-            generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
+            generate_reply, shared.input_params, output_params, show_progress=False) # .then(
             # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
         )
 
@@ -824,13 +826,13 @@ def create_interface():
             gen_events.append(shared.gradio['Regenerate'].click(
                 lambda x: x, shared.gradio['last_input'], shared.gradio['textbox'], show_progress=False).then(
                 ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
-                generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
+                generate_reply, shared.input_params, output_params, show_progress=False) # .then(
                 # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
             )
         else:
             gen_events.append(shared.gradio['Continue'].click(
                 ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
-                generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream) # .then(
+                generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=False) # .then(
                 # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
             )
 