From eeb63b1b8a9f75139ede2f842e1ed4e533ac5d40 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 7 Jan 2023 01:56:21 -0300 Subject: [PATCH] Fix galactica equations --- server.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/server.py b/server.py index 8e416ce8..b9f83e2d 100644 --- a/server.py +++ b/server.py @@ -60,6 +60,11 @@ def fix_gpt4chan(s): s = re.sub("--- [0-9]*\n\n\n---", "---", s) return s +def fix_galactica(s): + s = s.replace(r'\[', r'$') + s = s.replace(r'\]', r'$') + return s + def generate_reply(question, temperature, max_length, inference_settings, selected_model): global model, tokenizer, model_name, loaded_preset, preset @@ -81,12 +86,11 @@ def generate_reply(question, temperature, max_length, inference_settings, select output = eval(f"model.generate(input_ids, {preset}).cuda()") reply = tokenizer.decode(output[0], skip_special_tokens=True) - if model_name.startswith('gpt4chan'): - reply = fix_gpt4chan(reply) - if model_name.lower().startswith('galactica'): + reply = fix_galactica(reply) return reply, reply, 'Only applicable for gpt4chan.' elif model_name.lower().startswith('gpt4chan'): + reply = fix_gpt4chan(reply) return reply, 'Only applicable for galactica models.', generate_html(reply) else: return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.'