Fix galactica equations

This commit is contained in:
oobabooga 2023-01-07 01:56:21 -03:00
parent 538998b43b
commit eeb63b1b8a

View File

@ -60,6 +60,11 @@ def fix_gpt4chan(s):
s = re.sub("--- [0-9]*\n\n\n---", "---", s) s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s return s
def fix_galactica(s):
s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$')
return s
def generate_reply(question, temperature, max_length, inference_settings, selected_model): def generate_reply(question, temperature, max_length, inference_settings, selected_model):
global model, tokenizer, model_name, loaded_preset, preset global model, tokenizer, model_name, loaded_preset, preset
@ -81,12 +86,11 @@ def generate_reply(question, temperature, max_length, inference_settings, select
output = eval(f"model.generate(input_ids, {preset}).cuda()") output = eval(f"model.generate(input_ids, {preset}).cuda()")
reply = tokenizer.decode(output[0], skip_special_tokens=True) reply = tokenizer.decode(output[0], skip_special_tokens=True)
if model_name.startswith('gpt4chan'):
reply = fix_gpt4chan(reply)
if model_name.lower().startswith('galactica'): if model_name.lower().startswith('galactica'):
reply = fix_galactica(reply)
return reply, reply, 'Only applicable for gpt4chan.' return reply, reply, 'Only applicable for gpt4chan.'
elif model_name.lower().startswith('gpt4chan'): elif model_name.lower().startswith('gpt4chan'):
reply = fix_gpt4chan(reply)
return reply, 'Only applicable for galactica models.', generate_html(reply) return reply, 'Only applicable for galactica models.', generate_html(reply)
else: else:
return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.' return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.'