mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-28 18:48:04 +01:00
Further refactor
This commit is contained in:
parent
e46c43afa6
commit
364529d0c7
44
server.py
44
server.py
@ -25,11 +25,27 @@ from modules.text_generation import generate_reply
|
|||||||
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
|
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
|
||||||
print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n")
|
print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n")
|
||||||
|
|
||||||
|
# Loading custom settings
|
||||||
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
||||||
new_settings = json.loads(open(Path(shared.args.settings), 'r').read())
|
new_settings = json.loads(open(Path(shared.args.settings), 'r').read())
|
||||||
for item in new_settings:
|
for item in new_settings:
|
||||||
shared.settings[item] = new_settings[item]
|
shared.settings[item] = new_settings[item]
|
||||||
|
|
||||||
|
def get_available_models():
|
||||||
|
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
|
||||||
|
|
||||||
|
def get_available_presets():
|
||||||
|
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
|
||||||
|
|
||||||
|
def get_available_characters():
|
||||||
|
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
|
||||||
|
|
||||||
|
def get_available_extensions():
|
||||||
|
return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
|
||||||
|
|
||||||
|
def get_available_softprompts():
|
||||||
|
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
|
||||||
|
|
||||||
def load_model_wrapper(selected_model):
|
def load_model_wrapper(selected_model):
|
||||||
if selected_model != shared.model_name:
|
if selected_model != shared.model_name:
|
||||||
shared.model_name = selected_model
|
shared.model_name = selected_model
|
||||||
@ -82,21 +98,6 @@ def upload_soft_prompt(file):
|
|||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def get_available_models():
|
|
||||||
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
|
|
||||||
|
|
||||||
def get_available_presets():
|
|
||||||
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
|
|
||||||
|
|
||||||
def get_available_characters():
|
|
||||||
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
|
|
||||||
|
|
||||||
def get_available_extensions():
|
|
||||||
return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
|
|
||||||
|
|
||||||
def get_available_softprompts():
|
|
||||||
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
|
|
||||||
|
|
||||||
def create_extensions_block():
|
def create_extensions_block():
|
||||||
extensions_ui_elements = []
|
extensions_ui_elements = []
|
||||||
default_values = []
|
default_values = []
|
||||||
@ -171,7 +172,6 @@ def create_settings_menus():
|
|||||||
upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu])
|
upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu])
|
||||||
return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
|
return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
|
||||||
|
|
||||||
# Global variables
|
|
||||||
available_models = get_available_models()
|
available_models = get_available_models()
|
||||||
available_presets = get_available_presets()
|
available_presets = get_available_presets()
|
||||||
available_characters = get_available_characters()
|
available_characters = get_available_characters()
|
||||||
@ -192,7 +192,7 @@ else:
|
|||||||
i = 0
|
i = 0
|
||||||
else:
|
else:
|
||||||
print("The following models are available:\n")
|
print("The following models are available:\n")
|
||||||
for i,model in enumerate(available_models):
|
for i, model in enumerate(available_models):
|
||||||
print(f"{i+1}. {model}")
|
print(f"{i+1}. {model}")
|
||||||
print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
|
print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
|
||||||
i = int(input())-1
|
i = int(input())-1
|
||||||
@ -201,20 +201,18 @@ else:
|
|||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
|
|
||||||
# UI settings
|
# UI settings
|
||||||
|
buttons = {}
|
||||||
|
gen_events = []
|
||||||
|
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
|
||||||
|
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
|
||||||
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
|
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
|
||||||
default_text = shared.settings['prompt_gpt4chan']
|
default_text = shared.settings['prompt_gpt4chan']
|
||||||
elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None:
|
elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None:
|
||||||
default_text = 'User: \n'
|
default_text = 'User: \n'
|
||||||
else:
|
else:
|
||||||
default_text = shared.settings['prompt']
|
default_text = shared.settings['prompt']
|
||||||
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
|
|
||||||
|
|
||||||
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
|
|
||||||
buttons = {}
|
|
||||||
gen_events = []
|
|
||||||
|
|
||||||
if shared.args.chat or shared.args.cai_chat:
|
if shared.args.chat or shared.args.cai_chat:
|
||||||
|
|
||||||
if Path(f'logs/persistent.json').exists():
|
if Path(f'logs/persistent.json').exists():
|
||||||
chat.load_history(open(Path(f'logs/persistent.json'), 'rb').read(), shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'])
|
chat.load_history(open(Path(f'logs/persistent.json'), 'rb').read(), shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user