Mirror of https://github.com/oobabooga/text-generation-webui.git
Commit 1cb9246160 ("Adapt to the new model names")
Parent: 0345e04249
@@ -51,11 +51,12 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
 def load_quantized(model_name):
     if not shared.args.model_type:
         # Try to determine model type from model name
-        if model_name.lower().startswith(('llama', 'alpaca')):
+        name = model_name.lower()
+        if any((k in name for k in ['llama', 'alpaca'])):
             model_type = 'llama'
-        elif model_name.lower().startswith(('opt', 'galactica')):
+        elif any((k in name for k in ['opt-', 'galactica'])):
             model_type = 'opt'
-        elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
+        elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
             model_type = 'gptj'
         else:
             print("Can't determine model type from model name. Please specify it manually using --model_type "
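
The move from startswith() to substring checks is what the commit title refers to: quantized checkpoints are now often named with the bit width or other tags around the base name, so anchoring the match at the start of the string is too strict. A minimal sketch of the new detection logic, using invented folder names:

def detect_model_type(model_name):
    # Substring matching, as in the new load_quantized() above.
    name = model_name.lower()
    if any(k in name for k in ['llama', 'alpaca']):
        return 'llama'
    elif any(k in name for k in ['opt-', 'galactica']):
        return 'opt'
    elif any(k in name for k in ['gpt-j', 'pygmalion-6b']):
        return 'gptj'
    return None

print(detect_model_type('llama-13b-4bit'))  # 'llama' (the old startswith also caught this)
print(detect_model_type('4bit-llama-13b'))  # 'llama' (startswith(('llama', ...)) would miss this)
print(detect_model_type('opt-6.7b'))        # 'opt'

Note the key 'opt-' rather than 'opt': with substring matching, the trailing hyphen avoids false positives from names that merely contain the letters "opt".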
@@ -41,7 +41,7 @@ def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
 
-    shared.is_RWKV = model_name.lower().startswith('rwkv-')
+    shared.is_RWKV = 'rwkv-' in model_name.lower()
 
     # Default settings
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
@@ -159,7 +159,7 @@ def load_model(model_name):
         model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
 
     # Loading the tokenizer
-    if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+    if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
@@ -37,10 +37,6 @@ settings = {
     'chat_generation_attempts': 1,
     'chat_generation_attempts_min': 1,
     'chat_generation_attempts_max': 5,
-    'name1_pygmalion': 'You',
-    'name2_pygmalion': 'Kawaii',
-    'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
-    'stop_at_newline_pygmalion': False,
     'default_extensions': [],
     'chat_default_extensions': ["gallery"],
     'presets': {
@@ -42,7 +42,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
 
 def decode(output_ids):
     # Open Assistant relies on special tokens like <|endoftext|>
-    if re.match('(oasst|galactica)-*', shared.model_name.lower()):
+    if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
         return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
     else:
         reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
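
Since re.match() anchors at the start of the string, the old pattern only matched names that literally begin with "oasst" or "galactica"; the added ".*" prefix makes the check succeed anywhere in the name. A quick illustration with invented model names:

import re

for name in ['oasst-sft-1', 'my-galactica-clone', 'llama-13b']:
    old = bool(re.match('(oasst|galactica)-*', name))    # anchored at position 0
    new = bool(re.match('.*(oasst|galactica)-*', name))  # '.*' lets the match start anywhere
    print(name, old, new)
# oasst-sft-1        True  True
# my-galactica-clone False True
# llama-13b          False False

(The trailing '-*' means "zero or more hyphens", so the new pattern is effectively a substring test.)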
@@ -77,10 +77,10 @@ def fix_galactica(s):
 
 def formatted_outputs(reply, model_name):
     if not (shared.args.chat or shared.args.cai_chat):
-        if model_name.lower().startswith('galactica'):
+        if 'galactica' in model_name.lower():
             reply = fix_galactica(reply)
             return reply, reply, generate_basic_html(reply)
-        elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
+        elif any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])):
             reply = fix_gpt4chan(reply)
             return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
         else:
server.py (13 changes)
@@ -282,7 +282,6 @@ else:
     default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
 title ='Text generation web UI'
 description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
-suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
 
 def create_interface():
 
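
For context on the deleted suffix line: it used to redirect chat-settings lookups to the *_pygmalion keys (the ones removed from the settings dict above) whenever a Pygmalion model was loaded. A sketch of the old lookup path, with placeholder values:

# Pre-commit behavior, sketched; the settings values here are invented.
settings = {
    'name2': 'Assistant',         # generic default (placeholder value)
    'name2_pygmalion': 'Kawaii',  # Pygmalion-specific override, now removed
}

model_name = 'pygmalion-6b'
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
print(settings[f'name2{suffix}'])  # 'Kawaii' here; after this commit it is always settings['name2']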
@@ -294,7 +293,7 @@ def create_interface():
     if shared.args.chat or shared.args.cai_chat:
         with gr.Tab("Text generation", elem_id="main"):
             if shared.args.cai_chat:
-                shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
+                shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
             else:
                 shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
             shared.gradio['textbox'] = gr.Textbox(label='Input')
@@ -314,9 +313,9 @@ def create_interface():
             shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
 
         with gr.Tab("Character", elem_id="chat-settings"):
-            shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
-            shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
-            shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
+            shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
+            shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Bot\'s name')
+            shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=5, label='Context')
             with gr.Row():
                 shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
                 ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@@ -354,7 +353,7 @@ def create_interface():
             shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
         with gr.Column():
             shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
-            shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
+            shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
 
         create_settings_menus(default_preset)
 
@@ -401,7 +400,7 @@ def create_interface():
         shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
 
         shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
-        shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
+        shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
         shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
 
     elif shared.args.notebook:
@@ -12,10 +12,6 @@
     "chat_generation_attempts": 1,
     "chat_generation_attempts_min": 1,
     "chat_generation_attempts_max": 5,
-    "name1_pygmalion": "You",
-    "name2_pygmalion": "Kawaii",
-    "context_pygmalion": "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
-    "stop_at_newline_pygmalion": false,
     "default_extensions": [],
     "chat_default_extensions": [
         "gallery"
@@ -29,10 +25,11 @@
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
         "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
         "(rosey|chip|joi)_.*_instruct.*": "User: \n",
-        "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
+        "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>",
+        "alpaca-*": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     },
     "lora_prompts": {
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
-        "alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
+        "(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     }
 }
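
These prompt keys are regular expressions, matched against the model name by the default_text line in the server.py hunk above. A small sketch of how the new "alpaca-*" entry gets selected, with invented model names and truncated prompt strings:

import re

prompts = {
    'default': '...',
    '^(gpt4chan|gpt-4chan|4chan)': '...',
    'oasst-*': '...',
    'alpaca-*': '...',
}

def pick_prompt_key(model_name):
    # Same selection logic as server.py: first key whose regex matches
    # the start of the lowercased model name, else 'default'.
    return next((k for k in prompts if re.match(k.lower(), model_name.lower())), 'default')

print(pick_prompt_key('alpaca-13b-4bit'))  # 'alpaca-*'
print(pick_prompt_key('LLaMA-7b'))         # 'default' (no key matches)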