mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Better way of finding the generated reply in the output string
This commit is contained in:
parent
d03b0ad7a8
commit
849e4c7f90
30
server.py
30
server.py
@ -136,14 +136,17 @@ def decode(output_ids):
|
|||||||
return reply
|
return reply
|
||||||
|
|
||||||
def formatted_outputs(reply, model_name):
|
def formatted_outputs(reply, model_name):
|
||||||
if model_name.lower().startswith('galactica'):
|
if not (args.chat or args.cai_chat):
|
||||||
reply = fix_galactica(reply)
|
if model_name.lower().startswith('galactica'):
|
||||||
return reply, reply, generate_basic_html(reply)
|
reply = fix_galactica(reply)
|
||||||
elif model_name.lower().startswith('gpt4chan'):
|
return reply, reply, generate_basic_html(reply)
|
||||||
reply = fix_gpt4chan(reply)
|
elif model_name.lower().startswith('gpt4chan'):
|
||||||
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
|
reply = fix_gpt4chan(reply)
|
||||||
|
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
|
||||||
|
else:
|
||||||
|
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
|
||||||
else:
|
else:
|
||||||
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
|
return reply
|
||||||
|
|
||||||
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
|
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
|
||||||
global model, tokenizer, model_name, loaded_preset, preset
|
global model, tokenizer, model_name, loaded_preset, preset
|
||||||
@ -245,16 +248,17 @@ if args.chat or args.cai_chat:
|
|||||||
question = generate_chat_prompt(text, tokens, name1, name2, context)
|
question = generate_chat_prompt(text, tokens, name1, name2, context)
|
||||||
history.append(['', ''])
|
history.append(['', ''])
|
||||||
eos_token = '\n' if check else None
|
eos_token = '\n' if check else None
|
||||||
for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
|
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
|
||||||
reply = i[0]
|
|
||||||
next_character_found = False
|
next_character_found = False
|
||||||
|
|
||||||
|
previous_idx = [m.start() for m in re.finditer(f"\n{name2}:", question)]
|
||||||
|
idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", reply)]
|
||||||
|
idx = idx[len(previous_idx)-1]
|
||||||
|
reply = reply[idx + len(f"\n{name2}:"):]
|
||||||
|
|
||||||
if check:
|
if check:
|
||||||
idx = reply.rfind(question[-1024:])
|
reply = reply.split('\n')[0].strip()
|
||||||
reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
|
|
||||||
else:
|
else:
|
||||||
idx = reply.rfind(question[-1024:])
|
|
||||||
reply = reply[idx+min(1024, len(question)):]
|
|
||||||
idx = reply.find(f"\n{name1}:")
|
idx = reply.find(f"\n{name1}:")
|
||||||
if idx != -1:
|
if idx != -1:
|
||||||
reply = reply[:idx]
|
reply = reply[:idx]
|
||||||
|
Loading…
Reference in New Issue
Block a user