mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-11 21:10:40 +01:00
Better variable names
This commit is contained in:
parent
fc0493d885
commit
6be571cff7
55
server.py
55
server.py
@ -115,17 +115,17 @@ def load_model(model_name):
|
|||||||
# Custom
|
# Custom
|
||||||
else:
|
else:
|
||||||
command = "AutoModelForCausalLM.from_pretrained"
|
command = "AutoModelForCausalLM.from_pretrained"
|
||||||
settings = ["low_cpu_mem_usage=True"]
|
params = ["low_cpu_mem_usage=True"]
|
||||||
|
|
||||||
if args.cpu:
|
if args.cpu:
|
||||||
settings.append("low_cpu_mem_usage=True")
|
params.append("low_cpu_mem_usage=True")
|
||||||
settings.append("torch_dtype=torch.float32")
|
params.append("torch_dtype=torch.float32")
|
||||||
else:
|
else:
|
||||||
settings.append("device_map='auto'")
|
params.append("device_map='auto'")
|
||||||
settings.append("load_in_8bit=True" if args.load_in_8bit else "torch_dtype=torch.float16")
|
params.append("load_in_8bit=True" if args.load_in_8bit else "torch_dtype=torch.float16")
|
||||||
|
|
||||||
if args.gpu_memory:
|
if args.gpu_memory:
|
||||||
settings.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
|
params.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
|
||||||
elif (args.gpu_memory or args.cpu_memory) and not args.load_in_8bit:
|
elif (args.gpu_memory or args.cpu_memory) and not args.load_in_8bit:
|
||||||
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
|
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
|
||||||
suggestion = round((total_mem-1000)/1000)*1000
|
suggestion = round((total_mem-1000)/1000)*1000
|
||||||
@ -133,11 +133,11 @@ def load_model(model_name):
|
|||||||
suggestion -= 1000
|
suggestion -= 1000
|
||||||
suggestion = int(round(suggestion/1000))
|
suggestion = int(round(suggestion/1000))
|
||||||
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
||||||
settings.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
|
params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
|
||||||
if args.disk:
|
if args.disk:
|
||||||
settings.append(f"offload_folder='{args.disk_cache_dir or 'cache'}'")
|
params.append(f"offload_folder='{args.disk_cache_dir or 'cache'}'")
|
||||||
|
|
||||||
command = f"{command}(Path(f'models/{model_name}'), {','.join(set(settings))})"
|
command = f"{command}(Path(f'models/{model_name}'), {','.join(set(params))})"
|
||||||
model = eval(command)
|
model = eval(command)
|
||||||
|
|
||||||
# Loading the tokenizer
|
# Loading the tokenizer
|
||||||
@ -162,7 +162,7 @@ def load_model_wrapper(selected_model):
|
|||||||
model, tokenizer = load_model(model_name)
|
model, tokenizer = load_model(model_name)
|
||||||
|
|
||||||
def load_preset_values(preset_menu, return_dict=False):
|
def load_preset_values(preset_menu, return_dict=False):
|
||||||
settings = {
|
generate_params = {
|
||||||
'do_sample': True,
|
'do_sample': True,
|
||||||
'temperature': 1,
|
'temperature': 1,
|
||||||
'top_p': 1,
|
'top_p': 1,
|
||||||
@ -180,14 +180,14 @@ def load_preset_values(preset_menu, return_dict=False):
|
|||||||
for i in preset.split(','):
|
for i in preset.split(','):
|
||||||
i = i.strip().split('=')
|
i = i.strip().split('=')
|
||||||
if len(i) == 2 and i[0].strip() != 'tokens':
|
if len(i) == 2 and i[0].strip() != 'tokens':
|
||||||
settings[i[0].strip()] = eval(i[1].strip())
|
generate_params[i[0].strip()] = eval(i[1].strip())
|
||||||
|
|
||||||
settings['temperature'] = min(1.99, settings['temperature'])
|
generate_params['temperature'] = min(1.99, generate_params['temperature'])
|
||||||
|
|
||||||
if return_dict:
|
if return_dict:
|
||||||
return settings
|
return generate_params
|
||||||
else:
|
else:
|
||||||
return settings['do_sample'], settings['temperature'], settings['top_p'], settings['typical_p'], settings['repetition_penalty'], settings['top_k'], settings['min_length'], settings['no_repeat_ngram_size'], settings['num_beams'], settings['length_penalty'], settings['early_stopping']
|
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['length_penalty'], generate_params['early_stopping']
|
||||||
|
|
||||||
# Removes empty replies from gpt4chan outputs
|
# Removes empty replies from gpt4chan outputs
|
||||||
def fix_gpt4chan(s):
|
def fix_gpt4chan(s):
|
||||||
@ -365,7 +365,7 @@ def create_extensions_block():
|
|||||||
btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
|
btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
|
||||||
|
|
||||||
def create_settings_menus():
|
def create_settings_menus():
|
||||||
defaults = load_preset_values(settings[f'preset{suffix}'], return_dict=True)
|
generate_params = load_preset_values(settings[f'preset{suffix}'], return_dict=True)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@ -380,23 +380,23 @@ def create_settings_menus():
|
|||||||
with gr.Accordion("Custom generation parameters", open=False):
|
with gr.Accordion("Custom generation parameters", open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
do_sample = gr.Checkbox(value=defaults['do_sample'], label="do_sample")
|
do_sample = gr.Checkbox(value=generate_params['do_sample'], label="do_sample")
|
||||||
temperature = gr.Slider(0.01, 1.99, value=defaults['temperature'], step=0.01, label="temperature")
|
temperature = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label="temperature")
|
||||||
top_p = gr.Slider(0.0,1.0,value=defaults['top_p'],step=0.01,label="top_p")
|
top_p = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label="top_p")
|
||||||
typical_p = gr.Slider(0.0,1.0,value=defaults['typical_p'],step=0.01,label="typical_p")
|
typical_p = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label="typical_p")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
repetition_penalty = gr.Slider(1.0,4.99,value=defaults['repetition_penalty'],step=0.01,label="repetition_penalty")
|
repetition_penalty = gr.Slider(1.0,4.99,value=generate_params['repetition_penalty'],step=0.01,label="repetition_penalty")
|
||||||
top_k = gr.Slider(0,200,value=defaults['top_k'],step=1,label="top_k")
|
top_k = gr.Slider(0,200,value=generate_params['top_k'],step=1,label="top_k")
|
||||||
no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=defaults["no_repeat_ngram_size"], label="no_repeat_ngram_size")
|
no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=generate_params["no_repeat_ngram_size"], label="no_repeat_ngram_size")
|
||||||
|
|
||||||
gr.Markdown("Special parameters (only use them if you really need them):")
|
gr.Markdown("Special parameters (only use them if you really need them):")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
num_beams = gr.Slider(0, 20, step=1, value=defaults["num_beams"], label="num_beams")
|
num_beams = gr.Slider(0, 20, step=1, value=generate_params["num_beams"], label="num_beams")
|
||||||
length_penalty = gr.Slider(-5, 5, value=defaults["length_penalty"], label="length_penalty")
|
length_penalty = gr.Slider(-5, 5, value=generate_params["length_penalty"], label="length_penalty")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
min_length = gr.Slider(0, 2000, step=1, value=defaults["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
|
min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
|
||||||
early_stopping = gr.Checkbox(value=defaults["early_stopping"], label="early_stopping")
|
early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
|
||||||
|
|
||||||
model_menu.change(load_model_wrapper, [model_menu], [])
|
model_menu.change(load_model_wrapper, [model_menu], [])
|
||||||
preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping])
|
preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping])
|
||||||
@ -737,10 +737,9 @@ loaded_preset = None
|
|||||||
default_text = settings['prompt_gpt4chan'] if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) else settings['prompt']
|
default_text = settings['prompt_gpt4chan'] if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) else settings['prompt']
|
||||||
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
|
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
|
||||||
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
|
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
|
||||||
|
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
|
||||||
buttons = {}
|
buttons = {}
|
||||||
gen_events = []
|
gen_events = []
|
||||||
|
|
||||||
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
|
|
||||||
history = {'internal': [], 'visible': []}
|
history = {'internal': [], 'visible': []}
|
||||||
character = None
|
character = None
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user