From 6456777b09b472657605307617385dda8857b243 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 16 Jan 2023 16:35:45 -0300
Subject: [PATCH] Clean things up

---
 README.md           |  2 +-
 convert-to-torch.py |  4 +---
 download-model.py   |  1 -
 html_generator.py   |  4 +---
 server.py           | 28 +++++++++++++---------------
 5 files changed, 16 insertions(+), 23 deletions(-)

diff --git a/README.md b/README.md
index ed250351..0075bde0 100644
--- a/README.md
+++ b/README.md
@@ -133,7 +133,7 @@ Optionally, you can use the following command-line flags:
 | `--load-in-8bit` | Load the model with 8-bit precision.|
 | `--max-gpu-memory MAX_GPU_MEMORY` | Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number. |
 | `--no-listen` | Make the web UI unreachable from your local network.|
-| `--settings-file SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
+| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
 
 ## Presets
 
diff --git a/convert-to-torch.py b/convert-to-torch.py
index ab07bbcf..8159b67f 100644
--- a/convert-to-torch.py
+++ b/convert-to-torch.py
@@ -17,7 +17,5 @@ model_name = path.name
 
 print(f"Loading {model_name}...")
 model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
-print("Model loaded.")
-
-print(f"Saving to torch-dumps/{model_name}.pt")
+print(f"Model loaded.\nSaving to torch-dumps/{model_name}.pt")
 torch.save(model, Path(f"torch-dumps/{model_name}.pt"))
diff --git a/download-model.py b/download-model.py
index 8b3d2502..9733fcfa 100644
--- a/download-model.py
+++ b/download-model.py
@@ -28,7 +28,6 @@ def get_file(args):
     t.close()
 
 if __name__ == '__main__':
-
     model = argv[1]
     if model[-1] == '/':
         model = model[:-1]
diff --git a/html_generator.py b/html_generator.py
index 71eb299e..630b0d04 100644
--- a/html_generator.py
+++ b/html_generator.py
@@ -1,6 +1,6 @@
 '''
 
-This is a library for formatting gpt4chan outputs as nice HTML.
+This is a library for formatting GPT-4chan and chat outputs as nice HTML.
 
 '''
 
@@ -267,7 +267,5 @@ def generate_chat_html(history, name1, name2):
 
         """
 
-    output += ''
     output += ""
-
     return output
diff --git a/server.py b/server.py
index fabd3d93..2b76510b 100644
--- a/server.py
+++ b/server.py
@@ -25,14 +25,12 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
 parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
 parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.')
-parser.add_argument('--settings-file', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
+parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
 args = parser.parse_args()
 
 loaded_preset = None
-available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
-available_models = [item for item in available_models if not item.endswith('.txt')]
-available_models = sorted(available_models, key=str.lower)
-available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], list(Path('presets').glob('*.txt')))))
+available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
+available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], Path('presets').glob('*.txt'))), key=str.lower)
 
 settings = {
     'max_new_tokens': 200,
@@ -50,12 +48,12 @@ settings = {
     'stop_at_newline': True,
 }
 
-if args.settings_file is not None and Path(args.settings_file).exists():
-    with open(Path(args.settings_file), 'r') as f:
+if args.settings is not None and Path(args.settings).exists():
+    with open(Path(args.settings), 'r') as f:
         new_settings = json.load(f)
-        for i in new_settings:
-            if i in settings:
-                settings[i] = new_settings[i]
+        for item in new_settings:
+            if item in settings:
+                settings[item] = new_settings[item]
 
 def load_model(model_name):
     print(f"Loading {model_name}...")
@@ -87,7 +85,7 @@ def load_model(model_name):
         else:
             settings.append("torch_dtype=torch.float16")
 
-        settings = ', '.join(list(set(settings)))
+        settings = ', '.join(set(settings))
         command = f"{command}(Path(f'models/{model_name}'), {settings})"
         model = eval(command)
 
@@ -109,7 +107,7 @@ def fix_gpt4chan(s):
         s = re.sub("--- [0-9]*\n\n\n---", "---", s)
     return s
 
-# Fix the LaTeX equations in GALACTICA
+# Fix the LaTeX equations in galactica
def fix_galactica(s):
     s = s.replace(r'\[', r'$')
     s = s.replace(r'\]', r'$')
@@ -154,9 +152,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
         return reply, reply, generate_basic_html(reply)
     elif model_name.lower().startswith('gpt4chan'):
         reply = fix_gpt4chan(reply)
-        return reply, 'Only applicable for galactica models.', generate_4chan_html(reply)
+        return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
     else:
-        return reply, 'Only applicable for galactica models.', generate_basic_html(reply)
+        return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
 
 # Choosing the default model
 if args.model is not None:
@@ -219,7 +217,7 @@ elif args.chat or args.cai_chat:
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
         text = chat_response_cleaner(text)
 
-        question = context+'\n\n'
+        question = f"{context}\n\n"
         for i in range(len(history)):
             if args.cai_chat:
                 question += f"{name1}: {history[i][0].strip()}\n"