mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Adapt to the new model names
This commit is contained in:
parent
0345e04249
commit
1cb9246160
@ -51,11 +51,12 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
||||
def load_quantized(model_name):
|
||||
if not shared.args.model_type:
|
||||
# Try to determine model type from model name
|
||||
if model_name.lower().startswith(('llama', 'alpaca')):
|
||||
name = model_name.lower()
|
||||
if any((k in name for k in ['llama', 'alpaca'])):
|
||||
model_type = 'llama'
|
||||
elif model_name.lower().startswith(('opt', 'galactica')):
|
||||
elif any((k in name for k in ['opt-', 'galactica'])):
|
||||
model_type = 'opt'
|
||||
elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
|
||||
elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
|
||||
model_type = 'gptj'
|
||||
else:
|
||||
print("Can't determine model type from model name. Please specify it manually using --model_type "
|
||||
|
@ -41,7 +41,7 @@ def load_model(model_name):
|
||||
print(f"Loading {model_name}...")
|
||||
t0 = time.time()
|
||||
|
||||
shared.is_RWKV = model_name.lower().startswith('rwkv-')
|
||||
shared.is_RWKV = 'rwkv-' in model_name.lower()
|
||||
|
||||
# Default settings
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
|
||||
@ -159,7 +159,7 @@ def load_model(model_name):
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
||||
|
||||
# Loading the tokenizer
|
||||
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
||||
|
@ -37,10 +37,6 @@ settings = {
|
||||
'chat_generation_attempts': 1,
|
||||
'chat_generation_attempts_min': 1,
|
||||
'chat_generation_attempts_max': 5,
|
||||
'name1_pygmalion': 'You',
|
||||
'name2_pygmalion': 'Kawaii',
|
||||
'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
|
||||
'stop_at_newline_pygmalion': False,
|
||||
'default_extensions': [],
|
||||
'chat_default_extensions': ["gallery"],
|
||||
'presets': {
|
||||
|
@ -42,7 +42,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
|
||||
def decode(output_ids):
|
||||
# Open Assistant relies on special tokens like <|endoftext|>
|
||||
if re.match('(oasst|galactica)-*', shared.model_name.lower()):
|
||||
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
|
||||
else:
|
||||
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||
@ -77,10 +77,10 @@ def fix_galactica(s):
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if not (shared.args.chat or shared.args.cai_chat):
|
||||
if model_name.lower().startswith('galactica'):
|
||||
if 'galactica' in model_name.lower():
|
||||
reply = fix_galactica(reply)
|
||||
return reply, reply, generate_basic_html(reply)
|
||||
elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
|
||||
elif any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])):
|
||||
reply = fix_gpt4chan(reply)
|
||||
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
|
||||
else:
|
||||
|
13
server.py
13
server.py
@ -282,7 +282,6 @@ else:
|
||||
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
|
||||
title ='Text generation web UI'
|
||||
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
|
||||
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
|
||||
|
||||
def create_interface():
|
||||
|
||||
@ -294,7 +293,7 @@ def create_interface():
|
||||
if shared.args.chat or shared.args.cai_chat:
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
if shared.args.cai_chat:
|
||||
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
|
||||
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
|
||||
else:
|
||||
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
|
||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||
@ -314,9 +313,9 @@ def create_interface():
|
||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||
|
||||
with gr.Tab("Character", elem_id="chat-settings"):
|
||||
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
|
||||
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
|
||||
shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
|
||||
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
|
||||
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Bot\'s name')
|
||||
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=5, label='Context')
|
||||
with gr.Row():
|
||||
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
|
||||
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
|
||||
@ -354,7 +353,7 @@ def create_interface():
|
||||
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
||||
with gr.Column():
|
||||
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
||||
shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
|
||||
shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
|
||||
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
@ -401,7 +400,7 @@ def create_interface():
|
||||
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
|
||||
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
|
||||
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
|
||||
|
||||
elif shared.args.notebook:
|
||||
|
@ -12,10 +12,6 @@
|
||||
"chat_generation_attempts": 1,
|
||||
"chat_generation_attempts_min": 1,
|
||||
"chat_generation_attempts_max": 5,
|
||||
"name1_pygmalion": "You",
|
||||
"name2_pygmalion": "Kawaii",
|
||||
"context_pygmalion": "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
|
||||
"stop_at_newline_pygmalion": false,
|
||||
"default_extensions": [],
|
||||
"chat_default_extensions": [
|
||||
"gallery"
|
||||
@ -29,10 +25,11 @@
|
||||
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
|
||||
"^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
|
||||
"(rosey|chip|joi)_.*_instruct.*": "User: \n",
|
||||
"oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
|
||||
"oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>",
|
||||
"alpaca-*": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
|
||||
},
|
||||
"lora_prompts": {
|
||||
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
|
||||
"alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
|
||||
"(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user