mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-29 21:50:16 +01:00
Add universal LLaMA tokenizer support
This commit is contained in:
parent
32d47e4bad
commit
7bb9036ac9
@ -193,15 +193,25 @@ def load_model(model_name):
|
|||||||
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||||
elif type(model) is transformers.LlamaForCausalLM:
|
elif type(model) is transformers.LlamaForCausalLM:
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
tokenizer = None
|
||||||
# Leaving this here until the LLaMA tokenizer gets figured out.
|
|
||||||
# For some people this fixes things, for others it causes an error.
|
# Try to load an universal LLaMA tokenizer
|
||||||
try:
|
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||||
tokenizer.eos_token_id = 2
|
if p.exists():
|
||||||
tokenizer.bos_token_id = 1
|
print(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||||
tokenizer.pad_token_id = 0
|
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||||
except:
|
break
|
||||||
pass
|
|
||||||
|
# Otherwise, load it from the model folder and hope that these
|
||||||
|
# are not outdated tokenizer files.
|
||||||
|
if tokenizer is None:
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
||||||
|
try:
|
||||||
|
tokenizer.eos_token_id = 2
|
||||||
|
tokenizer.bos_token_id = 1
|
||||||
|
tokenizer.pad_token_id = 0
|
||||||
|
except:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), trust_remote_code=trust_remote_code)
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), trust_remote_code=trust_remote_code)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user