Add universal LLaMA tokenizer support

This commit is contained in:
oobabooga 2023-04-19 21:23:51 -03:00
parent 32d47e4bad
commit 7bb9036ac9

View File

@ -193,15 +193,25 @@ def load_model(model_name):
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists(): if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
elif type(model) is transformers.LlamaForCausalLM: elif type(model) is transformers.LlamaForCausalLM:
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True) tokenizer = None
# Leaving this here until the LLaMA tokenizer gets figured out.
# For some people this fixes things, for others it causes an error. # Try to load a universal LLaMA tokenizer
try: for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
tokenizer.eos_token_id = 2 if p.exists():
tokenizer.bos_token_id = 1 print(f"Loading the universal LLaMA tokenizer from {p}...")
tokenizer.pad_token_id = 0 tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
except: break
pass
# Otherwise, load it from the model folder and hope that these
# are not outdated tokenizer files.
if tokenizer is None:
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
try:
tokenizer.eos_token_id = 2
tokenizer.bos_token_id = 1
tokenizer.pad_token_id = 0
except:
pass
else: else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), trust_remote_code=trust_remote_code) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), trust_remote_code=trust_remote_code)