mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-01 11:54:04 +01:00
Bump transformers (16-bit llama must be reconverted/redownloaded)
This commit is contained in:
parent
5f4f38ca5d
commit
113f94b61e
@@ -10,7 +10,7 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
from accelerate import infer_auto_device_map, init_empty_weights
|
from accelerate import infer_auto_device_map, init_empty_weights
|
||||||
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||||
BitsAndBytesConfig)
|
BitsAndBytesConfig, LlamaTokenizer)
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
@@ -172,6 +172,8 @@ def load_model(model_name):
|
|||||||
# Loading the tokenizer
|
# Loading the tokenizer
|
||||||
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||||
|
elif type(model) is transformers.LlamaForCausalLM:
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
||||||
else:
|
else:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
||||||
tokenizer.truncation_side = 'left'
|
tokenizer.truncation_side = 'left'
|
||||||
|
@@ -28,6 +28,10 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
|||||||
return input_ids
|
return input_ids
|
||||||
else:
|
else:
|
||||||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
||||||
|
|
||||||
|
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
||||||
|
input_ids = input_ids[:,1:]
|
||||||
|
|
||||||
if shared.args.cpu:
|
if shared.args.cpu:
|
||||||
return input_ids
|
return input_ids
|
||||||
elif shared.args.flexgen:
|
elif shared.args.flexgen:
|
||||||
|
@@ -13,4 +13,4 @@ safetensors==0.3.0
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
pyyaml
|
pyyaml
|
||||||
tqdm
|
tqdm
|
||||||
git+https://github.com/huggingface/transformers@9eae4aa57650c1dbe1becd4e0979f6ad1e572ac0
|
git+https://github.com/huggingface/transformers
|
||||||
|
Loading…
Reference in New Issue
Block a user