Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-11-22 08:07:56 +01:00
Use oasst tokenizer instead of universal tokenizer
commit 97a6a50d98
parent b6ff138084
@@ -24,7 +24,6 @@ llama-[0-9]*b-4bit$:
 .*(oasst-sft-1-pythia-12b|oasst-sft-6-llama-30b):
   mode: 'instruct'
   instruction_template: 'Open Assistant'
-  custom_stopping_strings: '"<|endoftext|>"'
 .*vicuna:
   mode: 'instruct'
   instruction_template: 'Vicuna-v0'
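The dropped custom_stopping_strings entry was a workaround: with the universal LLaMA tokenizer, '<|endoftext|>' had to be matched as a literal string to stop generation. Presumably the oasst models' own tokenizer maps that token to its eos id, so generation stops natively and the string rule becomes redundant. A minimal sketch of that assumption (not part of this commit; the model name is taken from the pattern above):

from transformers import AutoTokenizer

# Assumption being checked: the oasst tokenizer treats <|endoftext|>
# as its eos token, so generate() stops on eos_token_id by itself.
tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b")
print(tokenizer.eos_token)     # expected: <|endoftext|>
print(tokenizer.eos_token_id)  # the id generation stops on by default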
@@ -54,6 +54,8 @@ def find_model_type(model_name):
         return 'galactica'
     elif 'llava' in model_name_lower:
         return 'llava'
+    elif 'oasst' in model_name_lower:
+        return 'oasst'
     elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
         return 'gpt4chan'
     else:
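For reference, find_model_type routes on substrings of the lowercased model name, so branch order matters: 'oasst-sft-6-llama-30b' also contains 'llama', and the new oasst branch has to fire before any llama-based detection. A self-contained sketch of the updated chain (the 'unknown' fallback is my placeholder, not taken from this diff):

def find_model_type_sketch(model_name):
    # Mirrors the elif chain above; earlier branches win.
    model_name_lower = model_name.lower()
    if 'llava' in model_name_lower:
        return 'llava'
    elif 'oasst' in model_name_lower:  # added by this commit
        return 'oasst'
    elif any(k in model_name_lower for k in ['gpt4chan', 'gpt-4chan']):
        return 'gpt4chan'
    else:
        return 'unknown'  # hypothetical fallback for the sketch

assert find_model_type_sketch('oasst-sft-6-llama-30b') == 'oasst'
assert find_model_type_sketch('oasst-sft-1-pythia-12b') == 'oasst'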
@@ -227,7 +229,7 @@ def load_model(model_name):
     tokenizer = None
 
     # Try to load a universal LLaMA tokenizer
-    if shared.model_type != 'llava':
+    if shared.model_type not in ['llava', 'oasst']:
         for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
             if p.exists():
                 logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
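The effect of the widened guard: llava and now oasst models keep the tokenizer bundled with the model instead of having it replaced by the shared LLaMA tokenizer. A hypothetical helper (the name and return convention are mine, not the repo's) condensing the branch:

from pathlib import Path

def universal_tokenizer_dir(model_type, model_dir):
    # llava and (since this commit) oasst keep their bundled tokenizer.
    if model_type in ['llava', 'oasst']:
        return None
    # Otherwise prefer a shared LLaMA tokenizer directory if one exists.
    for p in [Path(f"{model_dir}/llama-tokenizer/"),
              Path(f"{model_dir}/oobabooga_llama-tokenizer/")]:
        if p.exists():
            return p
    return None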