Better variable names

This commit is contained in:
oobabooga 2023-02-08 00:19:20 -03:00
parent fc0493d885
commit 6be571cff7

View File

@ -115,17 +115,17 @@ def load_model(model_name):
# Custom # Custom
else: else:
command = "AutoModelForCausalLM.from_pretrained" command = "AutoModelForCausalLM.from_pretrained"
settings = ["low_cpu_mem_usage=True"] params = ["low_cpu_mem_usage=True"]
if args.cpu: if args.cpu:
settings.append("low_cpu_mem_usage=True") params.append("low_cpu_mem_usage=True")
settings.append("torch_dtype=torch.float32") params.append("torch_dtype=torch.float32")
else: else:
settings.append("device_map='auto'") params.append("device_map='auto'")
settings.append("load_in_8bit=True" if args.load_in_8bit else "torch_dtype=torch.float16") params.append("load_in_8bit=True" if args.load_in_8bit else "torch_dtype=torch.float16")
if args.gpu_memory: if args.gpu_memory:
settings.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}") params.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
elif (args.gpu_memory or args.cpu_memory) and not args.load_in_8bit: elif (args.gpu_memory or args.cpu_memory) and not args.load_in_8bit:
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024)) total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
suggestion = round((total_mem-1000)/1000)*1000 suggestion = round((total_mem-1000)/1000)*1000
@ -133,11 +133,11 @@ def load_model(model_name):
suggestion -= 1000 suggestion -= 1000
suggestion = int(round(suggestion/1000)) suggestion = int(round(suggestion/1000))
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m") print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
settings.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}") params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
if args.disk: if args.disk:
settings.append(f"offload_folder='{args.disk_cache_dir or 'cache'}'") params.append(f"offload_folder='{args.disk_cache_dir or 'cache'}'")
command = f"{command}(Path(f'models/{model_name}'), {','.join(set(settings))})" command = f"{command}(Path(f'models/{model_name}'), {','.join(set(params))})"
model = eval(command) model = eval(command)
# Loading the tokenizer # Loading the tokenizer
@ -162,7 +162,7 @@ def load_model_wrapper(selected_model):
model, tokenizer = load_model(model_name) model, tokenizer = load_model(model_name)
def load_preset_values(preset_menu, return_dict=False): def load_preset_values(preset_menu, return_dict=False):
settings = { generate_params = {
'do_sample': True, 'do_sample': True,
'temperature': 1, 'temperature': 1,
'top_p': 1, 'top_p': 1,
@ -180,14 +180,14 @@ def load_preset_values(preset_menu, return_dict=False):
for i in preset.split(','): for i in preset.split(','):
i = i.strip().split('=') i = i.strip().split('=')
if len(i) == 2 and i[0].strip() != 'tokens': if len(i) == 2 and i[0].strip() != 'tokens':
settings[i[0].strip()] = eval(i[1].strip()) generate_params[i[0].strip()] = eval(i[1].strip())
settings['temperature'] = min(1.99, settings['temperature']) generate_params['temperature'] = min(1.99, generate_params['temperature'])
if return_dict: if return_dict:
return settings return generate_params
else: else:
return settings['do_sample'], settings['temperature'], settings['top_p'], settings['typical_p'], settings['repetition_penalty'], settings['top_k'], settings['min_length'], settings['no_repeat_ngram_size'], settings['num_beams'], settings['length_penalty'], settings['early_stopping'] return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['length_penalty'], generate_params['early_stopping']
# Removes empty replies from gpt4chan outputs # Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s): def fix_gpt4chan(s):
@ -365,7 +365,7 @@ def create_extensions_block():
btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], []) btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
def create_settings_menus(): def create_settings_menus():
defaults = load_preset_values(settings[f'preset{suffix}'], return_dict=True) generate_params = load_preset_values(settings[f'preset{suffix}'], return_dict=True)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -380,23 +380,23 @@ def create_settings_menus():
with gr.Accordion("Custom generation parameters", open=False): with gr.Accordion("Custom generation parameters", open=False):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
do_sample = gr.Checkbox(value=defaults['do_sample'], label="do_sample") do_sample = gr.Checkbox(value=generate_params['do_sample'], label="do_sample")
temperature = gr.Slider(0.01, 1.99, value=defaults['temperature'], step=0.01, label="temperature") temperature = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label="temperature")
top_p = gr.Slider(0.0,1.0,value=defaults['top_p'],step=0.01,label="top_p") top_p = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label="top_p")
typical_p = gr.Slider(0.0,1.0,value=defaults['typical_p'],step=0.01,label="typical_p") typical_p = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label="typical_p")
with gr.Column(): with gr.Column():
repetition_penalty = gr.Slider(1.0,4.99,value=defaults['repetition_penalty'],step=0.01,label="repetition_penalty") repetition_penalty = gr.Slider(1.0,4.99,value=generate_params['repetition_penalty'],step=0.01,label="repetition_penalty")
top_k = gr.Slider(0,200,value=defaults['top_k'],step=1,label="top_k") top_k = gr.Slider(0,200,value=generate_params['top_k'],step=1,label="top_k")
no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=defaults["no_repeat_ngram_size"], label="no_repeat_ngram_size") no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=generate_params["no_repeat_ngram_size"], label="no_repeat_ngram_size")
gr.Markdown("Special parameters (only use them if you really need them):") gr.Markdown("Special parameters (only use them if you really need them):")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
num_beams = gr.Slider(0, 20, step=1, value=defaults["num_beams"], label="num_beams") num_beams = gr.Slider(0, 20, step=1, value=generate_params["num_beams"], label="num_beams")
length_penalty = gr.Slider(-5, 5, value=defaults["length_penalty"], label="length_penalty") length_penalty = gr.Slider(-5, 5, value=generate_params["length_penalty"], label="length_penalty")
with gr.Column(): with gr.Column():
min_length = gr.Slider(0, 2000, step=1, value=defaults["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream) min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
early_stopping = gr.Checkbox(value=defaults["early_stopping"], label="early_stopping") early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
model_menu.change(load_model_wrapper, [model_menu], []) model_menu.change(load_model_wrapper, [model_menu], [])
preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping]) preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping])
@ -737,10 +737,9 @@ loaded_preset = None
default_text = settings['prompt_gpt4chan'] if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) else settings['prompt'] default_text = settings['prompt_gpt4chan'] if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) else settings['prompt']
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n" description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}" css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
buttons = {} buttons = {}
gen_events = [] gen_events = []
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
history = {'internal': [], 'visible': []} history = {'internal': [], 'visible': []}
character = None character = None