mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-11 21:10:40 +01:00
Merge pull request #2922 from Honkware/main
Load Salesforce Xgen Models
This commit is contained in:
commit
e0a50fb77a
@ -12,13 +12,7 @@ This guide will cover usage through the official `transformers` implementation.
|
||||
* Torrent: https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1484235789
|
||||
* Direct download: https://huggingface.co/Neko-Institute-of-Science
|
||||
|
||||
⚠️ The tokenizers for the Torrent source above and also for many LLaMA fine-tunes available on Hugging Face may be outdated, so I recommend downloading the following universal LLaMA tokenizer:
|
||||
|
||||
```
|
||||
python download-model.py oobabooga/llama-tokenizer
|
||||
```
|
||||
|
||||
Once downloaded, it will be automatically applied to **every** `LlamaForCausalLM` model that you try to load.
|
||||
⚠️ The tokenizers for the Torrent source above and also for many LLaMA fine-tunes available on Hugging Face may be outdated, in particular the files called `tokenizer_config.json` and `special_tokens_map.json`. Here you can find those files: https://huggingface.co/oobabooga/llama-tokenizer
|
||||
|
||||
### Option 2: convert the weights yourself
|
||||
|
||||
|
@ -240,3 +240,6 @@ TheBloke_WizardLM-30B-GPTQ:
|
||||
truncation_length: 8192
|
||||
.*superhot-8k:
|
||||
truncation_length: 8192
|
||||
.*xgen.*-inst:
|
||||
truncation_length: 8192
|
||||
instruction_template: 'Vicuna-v0'
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@ -14,7 +15,6 @@ from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
LlamaTokenizer
|
||||
)
|
||||
|
||||
import modules.shared as shared
|
||||
@ -91,30 +91,31 @@ def load_model(model_name, loader=None):
|
||||
|
||||
def load_tokenizer(model_name, model):
|
||||
tokenizer = None
|
||||
path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
|
||||
if any(s in model_name.lower() for s in ['gpt-4chan', 'gpt4chan']) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
elif model.__class__.__name__ in ['LlamaForCausalLM', 'LlamaGPTQForCausalLM', 'ExllamaHF']:
|
||||
# Try to load an universal LLaMA tokenizer
|
||||
if not any(s in shared.model_name.lower() for s in ['llava', 'oasst']):
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
logger.info(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||
return tokenizer
|
||||
elif path_to_model.exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
path_to_model,
|
||||
trust_remote_code=shared.args.trust_remote_code,
|
||||
use_fast=False
|
||||
)
|
||||
|
||||
# Otherwise, load it from the model folder and hope that these
|
||||
# are not outdated tokenizer files.
|
||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
|
||||
try:
|
||||
tokenizer.eos_token_id = 2
|
||||
tokenizer.bos_token_id = 1
|
||||
tokenizer.pad_token_id = 0
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
|
||||
if path_to_model.exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
|
||||
if tokenizer.__class__.__name__ == 'LlamaTokenizer':
|
||||
pairs = [
|
||||
['tokenizer_config.json', '516c6167c884793a738c440e29ccb80c15e1493ffc965affc69a1a8ddef4572a'],
|
||||
['special_tokens_map.json', 'ff3b4a612c4e447acb02d40071bddd989fe0da87eb5b7fe0dbadfc4f74de7531']
|
||||
]
|
||||
|
||||
for pair in pairs:
|
||||
p = path_to_model / pair[0]
|
||||
if p.exists():
|
||||
with open(p, "rb") as f:
|
||||
bytes = f.read()
|
||||
|
||||
file_hash = hashlib.sha256(bytes).hexdigest()
|
||||
if file_hash != pair[1]:
|
||||
logger.warning(f"{p} is different from the original LlamaTokenizer file. It is either customized or outdated.")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user