mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Better way of finding the generated reply in the output string
This commit is contained in:
parent
d03b0ad7a8
commit
849e4c7f90
30
server.py
30
server.py
@ -136,14 +136,17 @@ def decode(output_ids):
|
||||
return reply
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if model_name.lower().startswith('galactica'):
|
||||
reply = fix_galactica(reply)
|
||||
return reply, reply, generate_basic_html(reply)
|
||||
elif model_name.lower().startswith('gpt4chan'):
|
||||
reply = fix_gpt4chan(reply)
|
||||
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
|
||||
if not (args.chat or args.cai_chat):
|
||||
if model_name.lower().startswith('galactica'):
|
||||
reply = fix_galactica(reply)
|
||||
return reply, reply, generate_basic_html(reply)
|
||||
elif model_name.lower().startswith('gpt4chan'):
|
||||
reply = fix_gpt4chan(reply)
|
||||
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
|
||||
else:
|
||||
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
|
||||
else:
|
||||
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
|
||||
return reply
|
||||
|
||||
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
|
||||
global model, tokenizer, model_name, loaded_preset, preset
|
||||
@ -245,16 +248,17 @@ if args.chat or args.cai_chat:
|
||||
question = generate_chat_prompt(text, tokens, name1, name2, context)
|
||||
history.append(['', ''])
|
||||
eos_token = '\n' if check else None
|
||||
for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
|
||||
reply = i[0]
|
||||
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
|
||||
next_character_found = False
|
||||
|
||||
previous_idx = [m.start() for m in re.finditer(f"\n{name2}:", question)]
|
||||
idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", reply)]
|
||||
idx = idx[len(previous_idx)-1]
|
||||
reply = reply[idx + len(f"\n{name2}:"):]
|
||||
|
||||
if check:
|
||||
idx = reply.rfind(question[-1024:])
|
||||
reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
|
||||
reply = reply.split('\n')[0].strip()
|
||||
else:
|
||||
idx = reply.rfind(question[-1024:])
|
||||
reply = reply[idx+min(1024, len(question)):]
|
||||
idx = reply.find(f"\n{name1}:")
|
||||
if idx != -1:
|
||||
reply = reply[:idx]
|
||||
|
Loading…
Reference in New Issue
Block a user