Clean up the streaming implementation

This commit is contained in:
oobabooga 2023-01-19 10:43:05 -03:00
parent c90310e40e
commit 93fa9bbe01
2 changed files with 34 additions and 34 deletions

View File

@ -133,6 +133,7 @@ Optionally, you can use the following command-line flags:
| `--load-in-8bit` | Load the model with 8-bit precision.| | `--load-in-8bit` | Load the model with 8-bit precision.|
| `--max-gpu-memory MAX_GPU_MEMORY` | Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number. | | `--max-gpu-memory MAX_GPU_MEMORY` | Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number. |
| `--no-listen` | Make the web UI unreachable from your local network.| | `--no-listen` | Make the web UI unreachable from your local network.|
| `--no-stream` | Don't stream the text output in real time. This slightly improves the text generation performance.|
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.| | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
## Presets ## Presets

View File

@ -25,7 +25,7 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.') parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.') parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.') parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This slightly improves the text generation performance.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
args = parser.parse_args() args = parser.parse_args()
@ -125,6 +125,21 @@ def encode(prompt, tokens):
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens) input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
return input_ids return input_ids
def decode(output_ids):
reply = tokenizer.decode(output_ids, skip_special_tokens=True)
reply = reply.replace(r'<|endoftext|>', '')
return reply
def formatted_outputs(reply, model_name):
if model_name.lower().startswith('galactica'):
reply = fix_galactica(reply)
return reply, reply, generate_basic_html(reply)
elif model_name.lower().startswith('gpt4chan'):
reply = fix_gpt4chan(reply)
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else:
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None): def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
global model, tokenizer, model_name, loaded_preset, preset global model, tokenizer, model_name, loaded_preset, preset
@ -141,43 +156,27 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
loaded_preset = inference_settings loaded_preset = inference_settings
cuda = "" if args.cpu else ".cuda()" cuda = "" if args.cpu else ".cuda()"
if not args.no_stream: n = None if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
# Generate the entire reply at once
if args.no_stream:
input_ids = encode(question, tokens)
output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
reply = decode(output[0])
yield formatted_outputs(reply, model_name)
# Generate the reply 1 token at a time
else:
input_ids = encode(question, 1) input_ids = encode(question, 1)
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1') preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
for i in range(tokens): for i in range(tokens):
output = eval(f"model.generate(input_ids, {preset}){cuda}") output = eval(f"model.generate(input_ids, {preset}){cuda}")
reply = tokenizer.decode(output[0], skip_special_tokens=True) reply = decode(output[0])
reply = reply.replace(r'<|endoftext|>', '')
if eos_token is not None and reply[-1] == eos_token: if eos_token is not None and reply[-1] == eos_token:
break break
if model_name.lower().startswith('galactica'):
reply = fix_galactica(reply)
yield reply, reply, generate_basic_html(reply)
elif model_name.lower().startswith('gpt4chan'):
reply = fix_gpt4chan(reply)
yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else:
yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
yield formatted_outputs(reply, model_name)
input_ids = output input_ids = output
else:
input_ids = encode(question, tokens)
if eos_token is None:
output = eval(f"model.generate(input_ids, {preset}){cuda}")
else:
n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
reply = tokenizer.decode(output[0], skip_special_tokens=True)
reply = reply.replace(r'<|endoftext|>', '')
if model_name.lower().startswith('galactica'):
reply = fix_galactica(reply)
yield reply, reply, generate_basic_html(reply)
elif model_name.lower().startswith('gpt4chan'):
reply = fix_gpt4chan(reply)
yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else:
yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
# Choosing the default model # Choosing the default model
if args.model is not None: if args.model is not None:
@ -206,7 +205,6 @@ else:
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n" description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}" css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}"
if args.chat or args.cai_chat: if args.chat or args.cai_chat:
history = [] history = []
@ -257,20 +255,21 @@ if args.chat or args.cai_chat:
reply = clean_chat_message(reply) reply = clean_chat_message(reply)
history[-1] = [text, reply] history[-1] = [text, reply]
if next_character_found:
break
# Prevent the chat log from flashing if something like "\nYo" is generated just # Prevent the chat log from flashing if something like "\nYo" is generated just
# before "\nYou:" is completed # before "\nYou:" is completed
tmp = f"\n{name1}:" tmp = f"\n{name1}:"
next_character_substring_found = False next_character_substring_found = False
for j in range(1, len(tmp)+1): for j in range(1, len(tmp)):
if reply[-j:] == tmp[:j]: if reply[-j:] == tmp[:j]:
next_character_substring_found = True next_character_substring_found = True
if not next_character_substring_found: if not next_character_substring_found:
yield history yield history
if next_character_found: yield history
break
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check): def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check): for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):