Adapt to the new model names

This commit is contained in:
oobabooga 2023-03-29 21:47:36 -03:00
parent 0345e04249
commit 1cb9246160
6 changed files with 18 additions and 25 deletions

View File

@ -51,11 +51,12 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
def load_quantized(model_name): def load_quantized(model_name):
if not shared.args.model_type: if not shared.args.model_type:
# Try to determine model type from model name # Try to determine model type from model name
if model_name.lower().startswith(('llama', 'alpaca')): name = model_name.lower()
if any((k in name for k in ['llama', 'alpaca'])):
model_type = 'llama' model_type = 'llama'
elif model_name.lower().startswith(('opt', 'galactica')): elif any((k in name for k in ['opt-', 'galactica'])):
model_type = 'opt' model_type = 'opt'
elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')): elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
model_type = 'gptj' model_type = 'gptj'
else: else:
print("Can't determine model type from model name. Please specify it manually using --model_type " print("Can't determine model type from model name. Please specify it manually using --model_type "

View File

@ -41,7 +41,7 @@ def load_model(model_name):
print(f"Loading {model_name}...") print(f"Loading {model_name}...")
t0 = time.time() t0 = time.time()
shared.is_RWKV = model_name.lower().startswith('rwkv-') shared.is_RWKV = 'rwkv-' in model_name.lower()
# Default settings # Default settings
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]): if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
@ -159,7 +159,7 @@ def load_model(model_name):
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
# Loading the tokenizer # Loading the tokenizer
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists(): if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
else: else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))

View File

@ -37,10 +37,6 @@ settings = {
'chat_generation_attempts': 1, 'chat_generation_attempts': 1,
'chat_generation_attempts_min': 1, 'chat_generation_attempts_min': 1,
'chat_generation_attempts_max': 5, 'chat_generation_attempts_max': 5,
'name1_pygmalion': 'You',
'name2_pygmalion': 'Kawaii',
'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
'stop_at_newline_pygmalion': False,
'default_extensions': [], 'default_extensions': [],
'chat_default_extensions': ["gallery"], 'chat_default_extensions': ["gallery"],
'presets': { 'presets': {

View File

@ -42,7 +42,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
def decode(output_ids): def decode(output_ids):
# Open Assistant relies on special tokens like <|endoftext|> # Open Assistant relies on special tokens like <|endoftext|>
if re.match('(oasst|galactica)-*', shared.model_name.lower()): if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
return shared.tokenizer.decode(output_ids, skip_special_tokens=False) return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
else: else:
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True) reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
@ -77,10 +77,10 @@ def fix_galactica(s):
def formatted_outputs(reply, model_name): def formatted_outputs(reply, model_name):
if not (shared.args.chat or shared.args.cai_chat): if not (shared.args.chat or shared.args.cai_chat):
if model_name.lower().startswith('galactica'): if 'galactica' in model_name.lower():
reply = fix_galactica(reply) reply = fix_galactica(reply)
return reply, reply, generate_basic_html(reply) return reply, reply, generate_basic_html(reply)
elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): elif any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])):
reply = fix_gpt4chan(reply) reply = fix_gpt4chan(reply)
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply) return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else: else:

View File

@ -282,7 +282,6 @@ else:
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')] default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
title ='Text generation web UI' title ='Text generation web UI'
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n' description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
def create_interface(): def create_interface():
@ -294,7 +293,7 @@ def create_interface():
if shared.args.chat or shared.args.cai_chat: if shared.args.chat or shared.args.cai_chat:
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
if shared.args.cai_chat: if shared.args.cai_chat:
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
else: else:
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528")) shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
shared.gradio['textbox'] = gr.Textbox(label='Input') shared.gradio['textbox'] = gr.Textbox(label='Input')
@ -314,9 +313,9 @@ def create_interface():
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
with gr.Tab("Character", elem_id="chat-settings"): with gr.Tab("Character", elem_id="chat-settings"):
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name') shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name') shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Bot\'s name')
shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context') shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=5, label='Context')
with gr.Row(): with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@ -354,7 +353,7 @@ def create_interface():
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
with gr.Column(): with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
create_settings_menus(default_preset) create_settings_menus(default_preset)
@ -401,7 +400,7 @@ def create_interface():
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook: elif shared.args.notebook:

View File

@ -12,10 +12,6 @@
"chat_generation_attempts": 1, "chat_generation_attempts": 1,
"chat_generation_attempts_min": 1, "chat_generation_attempts_min": 1,
"chat_generation_attempts_max": 5, "chat_generation_attempts_max": 5,
"name1_pygmalion": "You",
"name2_pygmalion": "Kawaii",
"context_pygmalion": "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
"stop_at_newline_pygmalion": false,
"default_extensions": [], "default_extensions": [],
"chat_default_extensions": [ "chat_default_extensions": [
"gallery" "gallery"
@ -29,10 +25,11 @@
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:", "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
"^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n", "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
"(rosey|chip|joi)_.*_instruct.*": "User: \n", "(rosey|chip|joi)_.*_instruct.*": "User: \n",
"oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>" "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>",
"alpaca-*": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
}, },
"lora_prompts": { "lora_prompts": {
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:", "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
"alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n" "(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
} }
} }