mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Further refactor
This commit is contained in:
parent
98af4bfb0d
commit
ce7feb3641
@ -6,10 +6,13 @@ from pathlib import Path
|
|||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import *
|
from modules.html_generator import generate_chat_html
|
||||||
from modules.prompt import encode
|
from modules.text_generation import encode
|
||||||
from modules.prompt import generate_reply
|
from modules.text_generation import generate_reply
|
||||||
from modules.prompt import get_max_prompt_length
|
from modules.text_generation import get_max_prompt_length
|
||||||
|
|
||||||
|
if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
|
||||||
|
import modules.bot_picture as bot_picture
|
||||||
|
|
||||||
history = {'internal': [], 'visible': []}
|
history = {'internal': [], 'visible': []}
|
||||||
character = None
|
character = None
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
global tokenizer
|
|
||||||
|
|
||||||
model = None
|
model = None
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
model_name = ""
|
model_name = ""
|
||||||
|
@ -4,7 +4,8 @@ import modules.shared as shared
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import *
|
from modules.html_generator import generate_4chan_html
|
||||||
|
from modules.html_generator import generate_basic_html
|
||||||
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
|
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
27
server.py
27
server.py
@ -20,12 +20,12 @@ from transformers import AutoTokenizer
|
|||||||
import modules.chat as chat
|
import modules.chat as chat
|
||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
import modules.ui as ui
|
||||||
from modules.extensions import extension_state
|
from modules.extensions import extension_state
|
||||||
from modules.extensions import load_extensions
|
from modules.extensions import load_extensions
|
||||||
from modules.extensions import update_extensions_parameters
|
from modules.extensions import update_extensions_parameters
|
||||||
from modules.html_generator import *
|
from modules.html_generator import generate_chat_html
|
||||||
from modules.prompt import generate_reply
|
from modules.text_generation import generate_reply
|
||||||
from modules.ui import *
|
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
@ -74,9 +74,6 @@ if shared.args.deepspeed:
|
|||||||
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
|
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
|
||||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||||
|
|
||||||
if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
|
|
||||||
import modules.bot_picture as bot_picture
|
|
||||||
|
|
||||||
def load_model(model_name):
|
def load_model(model_name):
|
||||||
print(f"Loading {model_name}...")
|
print(f"Loading {model_name}...")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
@ -288,11 +285,11 @@ def create_settings_menus():
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
model_menu = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
|
model_menu = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
|
||||||
create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
|
ui.create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
||||||
create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
|
ui.create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
|
||||||
|
|
||||||
with gr.Accordion("Custom generation parameters", open=False, elem_id="accordion"):
|
with gr.Accordion("Custom generation parameters", open=False, elem_id="accordion"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -320,7 +317,7 @@ def create_settings_menus():
|
|||||||
with gr.Accordion("Soft prompt", open=False, elem_id="accordion"):
|
with gr.Accordion("Soft prompt", open=False, elem_id="accordion"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
|
softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
|
||||||
create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
|
ui.create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
|
||||||
|
|
||||||
gr.Markdown('Upload a soft prompt (.zip format):')
|
gr.Markdown('Upload a soft prompt (.zip format):')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -336,8 +333,9 @@ def create_settings_menus():
|
|||||||
available_models = get_available_models()
|
available_models = get_available_models()
|
||||||
available_presets = get_available_presets()
|
available_presets = get_available_presets()
|
||||||
available_characters = get_available_characters()
|
available_characters = get_available_characters()
|
||||||
extensions_module.available_extensions = get_available_extensions()
|
|
||||||
available_softprompts = get_available_softprompts()
|
available_softprompts = get_available_softprompts()
|
||||||
|
|
||||||
|
extensions_module.available_extensions = get_available_extensions()
|
||||||
if shared.args.extensions is not None:
|
if shared.args.extensions is not None:
|
||||||
load_extensions()
|
load_extensions()
|
||||||
|
|
||||||
@ -359,7 +357,6 @@ else:
|
|||||||
print()
|
print()
|
||||||
shared.model_name = available_models[i]
|
shared.model_name = available_models[i]
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
loaded_preset = None
|
|
||||||
|
|
||||||
# UI settings
|
# UI settings
|
||||||
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
|
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
|
||||||
@ -379,7 +376,7 @@ if shared.args.chat or shared.args.cai_chat:
|
|||||||
if Path(f'logs/persistent.json').exists():
|
if Path(f'logs/persistent.json').exists():
|
||||||
chat.load_history(open(Path(f'logs/persistent.json'), 'rb').read(), settings[f'name1{suffix}'], settings[f'name2{suffix}'])
|
chat.load_history(open(Path(f'logs/persistent.json'), 'rb').read(), settings[f'name1{suffix}'], settings[f'name2{suffix}'])
|
||||||
|
|
||||||
with gr.Blocks(css=css+chat_css, analytics_enabled=False) as interface:
|
with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as interface:
|
||||||
if shared.args.cai_chat:
|
if shared.args.cai_chat:
|
||||||
display = gr.HTML(value=generate_chat_html(chat.history['visible'], settings[f'name1{suffix}'], settings[f'name2{suffix}'], chat.character))
|
display = gr.HTML(value=generate_chat_html(chat.history['visible'], settings[f'name1{suffix}'], settings[f'name2{suffix}'], chat.character))
|
||||||
else:
|
else:
|
||||||
@ -406,7 +403,7 @@ if shared.args.chat or shared.args.cai_chat:
|
|||||||
context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
|
context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
|
character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
|
||||||
create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
|
ui.create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
|
check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
|
||||||
@ -489,7 +486,7 @@ if shared.args.chat or shared.args.cai_chat:
|
|||||||
upload_img_me.upload(lambda : chat.history['visible'], [], [display])
|
upload_img_me.upload(lambda : chat.history['visible'], [], [display])
|
||||||
|
|
||||||
elif shared.args.notebook:
|
elif shared.args.notebook:
|
||||||
with gr.Blocks(css=css, analytics_enabled=False) as interface:
|
with gr.Blocks(css=ui.css, analytics_enabled=False) as interface:
|
||||||
gr.Markdown(description)
|
gr.Markdown(description)
|
||||||
with gr.Tab('Raw'):
|
with gr.Tab('Raw'):
|
||||||
textbox = gr.Textbox(value=default_text, lines=23)
|
textbox = gr.Textbox(value=default_text, lines=23)
|
||||||
@ -513,7 +510,7 @@ elif shared.args.notebook:
|
|||||||
buttons["Stop"].click(None, None, None, cancels=gen_events)
|
buttons["Stop"].click(None, None, None, cancels=gen_events)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
with gr.Blocks(css=css, analytics_enabled=False) as interface:
|
with gr.Blocks(css=ui.css, analytics_enabled=False) as interface:
|
||||||
gr.Markdown(description)
|
gr.Markdown(description)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
Loading…
Reference in New Issue
Block a user