From 364529d0c7405b92aba90aba4204944818fd4f13 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Feb 2023 14:31:28 -0300
Subject: [PATCH] Further refactor

---
 server.py | 44 +++++++++++++++++++++-----------------------
 1 file changed, 21 insertions(+), 23 deletions(-)

diff --git a/server.py b/server.py
index 791d98a3..0cd21fcd 100644
--- a/server.py
+++ b/server.py
@@ -25,11 +25,27 @@ from modules.text_generation import generate_reply
 if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
     print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n")
 
+# Loading custom settings
 if shared.args.settings is not None and Path(shared.args.settings).exists():
     new_settings = json.loads(open(Path(shared.args.settings), 'r').read())
     for item in new_settings:
         shared.settings[item] = new_settings[item]
 
+def get_available_models():
+    return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
+
+def get_available_presets():
+    return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
+
+def get_available_characters():
+    return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
+
+def get_available_extensions():
+    return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
+
+def get_available_softprompts():
+    return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
+
 def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model
@@ -82,21 +98,6 @@ def upload_soft_prompt(file):
 
     return name
 
-def get_available_models():
-    return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
-
-def get_available_presets():
-    return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
-
-def get_available_characters():
-    return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
-
-def get_available_extensions():
-    return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
-
-def get_available_softprompts():
-    return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
-
 def create_extensions_block():
     extensions_ui_elements = []
     default_values = []
@@ -171,7 +172,6 @@ def create_settings_menus():
     upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu])
     return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
 
-# Global variables
 available_models = get_available_models()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
@@ -192,7 +192,7 @@ else:
         i = 0
     else:
         print("The following models are available:\n")
-        for i,model in enumerate(available_models):
+        for i, model in enumerate(available_models):
             print(f"{i+1}. {model}")
         print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
         i = int(input())-1
@@ -201,20 +201,18 @@ else:
 shared.model, shared.tokenizer = load_model(shared.model_name)
 
 # UI settings
+buttons = {}
+gen_events = []
+suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
+description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
 if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
     default_text = shared.settings['prompt_gpt4chan']
 elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None:
     default_text = 'User: \n'
 else:
     default_text = shared.settings['prompt']
-description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
-
-suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
-buttons = {}
-gen_events = []
 
 if shared.args.chat or shared.args.cai_chat:
-
     if Path(f'logs/persistent.json').exists():
         chat.load_history(open(Path(f'logs/persistent.json'), 'rb').read(), shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'])