mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add universal LLaMA tokenizer support
This commit is contained in:
parent
32d47e4bad
commit
7bb9036ac9
@ -193,15 +193,25 @@ def load_model(model_name):
|
||||
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
elif type(model) is transformers.LlamaForCausalLM:
|
||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
||||
# Leaving this here until the LLaMA tokenizer gets figured out.
|
||||
# For some people this fixes things, for others it causes an error.
|
||||
try:
|
||||
tokenizer.eos_token_id = 2
|
||||
tokenizer.bos_token_id = 1
|
||||
tokenizer.pad_token_id = 0
|
||||
except:
|
||||
pass
|
||||
tokenizer = None
|
||||
|
||||
# Try to load an universal LLaMA tokenizer
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
print(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||
break
|
||||
|
||||
# Otherwise, load it from the model folder and hope that these
|
||||
# are not outdated tokenizer files.
|
||||
if tokenizer is None:
|
||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
||||
try:
|
||||
tokenizer.eos_token_id = 2
|
||||
tokenizer.bos_token_id = 1
|
||||
tokenizer.pad_token_id = 0
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), trust_remote_code=trust_remote_code)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user