From ce7feb3641519852d43443399b466255277a27a8 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 23 Feb 2023 13:03:52 -0300 Subject: [PATCH] Further refactor --- modules/chat.py | 11 +++++---- modules/shared.py | 2 -- modules/{prompt.py => text_generation.py} | 3 ++- server.py | 27 ++++++++++------------- 4 files changed, 21 insertions(+), 22 deletions(-) rename modules/{prompt.py => text_generation.py} (98%) diff --git a/modules/chat.py b/modules/chat.py index d02acaae..8e054f17 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -6,10 +6,13 @@ from pathlib import Path import modules.shared as shared from modules.extensions import apply_extensions -from modules.html_generator import * -from modules.prompt import encode -from modules.prompt import generate_reply -from modules.prompt import get_max_prompt_length +from modules.html_generator import generate_chat_html +from modules.text_generation import encode +from modules.text_generation import generate_reply +from modules.text_generation import get_max_prompt_length + +if shared.args.picture and (shared.args.cai_chat or shared.args.chat): + import modules.bot_picture as bot_picture history = {'internal': [], 'visible': []} character = None diff --git a/modules/shared.py b/modules/shared.py index 72622894..3f9a1035 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,7 +1,5 @@ import argparse -global tokenizer - model = None tokenizer = None model_name = "" diff --git a/modules/prompt.py b/modules/text_generation.py similarity index 98% rename from modules/prompt.py rename to modules/text_generation.py index b95897aa..42fe3869 100644 --- a/modules/prompt.py +++ b/modules/text_generation.py @@ -4,7 +4,8 @@ import modules.shared as shared import torch import transformers from modules.extensions import apply_extensions -from modules.html_generator import * +from modules.html_generator import generate_4chan_html +from modules.html_generator import generate_basic_html from modules.stopping_criteria import _SentinelTokenStoppingCriteria from tqdm import tqdm diff --git a/server.py b/server.py index d1a0a01f..2ce0b40c 100644 --- a/server.py +++ b/server.py @@ -20,12 +20,12 @@ from transformers import AutoTokenizer import modules.chat as chat import modules.extensions as extensions_module import modules.shared as shared +import modules.ui as ui from modules.extensions import extension_state from modules.extensions import load_extensions from modules.extensions import update_extensions_parameters -from modules.html_generator import * -from modules.prompt import generate_reply -from modules.ui import * +from modules.html_generator import generate_chat_html +from modules.text_generation import generate_reply transformers.logging.set_verbosity_error() @@ -74,9 +74,6 @@ if shared.args.deepspeed: ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir) dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration -if shared.args.picture and (shared.args.cai_chat or shared.args.chat): - import modules.bot_picture as bot_picture - def load_model(model_name): print(f"Loading {model_name}...") t0 = time.time() @@ -288,11 +285,11 @@ def create_settings_menus(): with gr.Column(): with gr.Row(): model_menu = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model') - create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button") + ui.create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button") with gr.Column(): with gr.Row(): preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset') - create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button") + ui.create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button") with gr.Accordion("Custom generation parameters", open=False, elem_id="accordion"): with gr.Row(): @@ -320,7 +317,7 @@ def create_settings_menus(): with gr.Accordion("Soft prompt", open=False, elem_id="accordion"): with gr.Row(): softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt') - create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button") + ui.create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button") gr.Markdown('Upload a soft prompt (.zip format):') with gr.Row(): @@ -336,8 +333,9 @@ def create_settings_menus(): available_models = get_available_models() available_presets = get_available_presets() available_characters = get_available_characters() -extensions_module.available_extensions = get_available_extensions() available_softprompts = get_available_softprompts() + +extensions_module.available_extensions = get_available_extensions() if shared.args.extensions is not None: load_extensions() @@ -359,7 +357,6 @@ else: print() shared.model_name = available_models[i] shared.model, shared.tokenizer = load_model(shared.model_name) -loaded_preset = None # UI settings if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): @@ -379,7 +376,7 @@ if shared.args.chat or shared.args.cai_chat: if Path(f'logs/persistent.json').exists(): chat.load_history(open(Path(f'logs/persistent.json'), 'rb').read(), settings[f'name1{suffix}'], settings[f'name2{suffix}']) - with gr.Blocks(css=css+chat_css, analytics_enabled=False) as interface: + with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as interface: if shared.args.cai_chat: display = gr.HTML(value=generate_chat_html(chat.history['visible'], settings[f'name1{suffix}'], settings[f'name2{suffix}'], chat.character)) else: @@ -406,7 +403,7 @@ if shared.args.chat or shared.args.cai_chat: context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context') with gr.Row(): character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character') - create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button") + ui.create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button") with gr.Row(): check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') @@ -489,7 +486,7 @@ if shared.args.chat or shared.args.cai_chat: upload_img_me.upload(lambda : chat.history['visible'], [], [display]) elif shared.args.notebook: - with gr.Blocks(css=css, analytics_enabled=False) as interface: + with gr.Blocks(css=ui.css, analytics_enabled=False) as interface: gr.Markdown(description) with gr.Tab('Raw'): textbox = gr.Textbox(value=default_text, lines=23) @@ -513,7 +510,7 @@ elif shared.args.notebook: buttons["Stop"].click(None, None, None, cancels=gen_events) else: - with gr.Blocks(css=css, analytics_enabled=False) as interface: + with gr.Blocks(css=ui.css, analytics_enabled=False) as interface: gr.Markdown(description) with gr.Row(): with gr.Column():