Add ExLlamaV2 and ExLlamav2_HF loaders (#3881)

This commit is contained in:
oobabooga 2023-09-12 14:33:07 -03:00 committed by GitHub
parent a821928877
commit c2a309f56e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 295 additions and 5 deletions

View File

@ -210,7 +210,7 @@ llama-65b-gptq-3bit:
instruction_template: 'Alpaca' instruction_template: 'Alpaca'
.*llama-(2|v2): .*llama-(2|v2):
truncation_length: 4096 truncation_length: 4096
.*llama-(2|v2).*chat: .*llama(-?)(2|v2).*chat:
instruction_template: 'Llama-v2' instruction_template: 'Llama-v2'
.*newhope: .*newhope:
instruction_template: 'NewHope' instruction_template: 'NewHope'

102
modules/exllamav2.py Normal file
View File

@ -0,0 +1,102 @@
import random
from pathlib import Path
import torch
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Config,
ExLlamaV2Tokenizer
)
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
from modules import shared
from modules.text_generation import get_max_prompt_length
class Exllamav2Model:
def __init__(self):
pass
@classmethod
def from_pretrained(self, path_to_model):
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
config = ExLlamaV2Config()
config.model_dir = path_to_model
config.prepare()
config.max_seq_len = shared.args.max_seq_len
model = ExLlamaV2(config)
split = None
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
model.load(split)
tokenizer = ExLlamaV2Tokenizer(config)
cache = ExLlamaV2Cache(model)
generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
result = self()
result.model = model
result.cache = cache
result.tokenizer = tokenizer
result.generator = generator
return result, tokenizer
def generate_with_streaming(self, prompt, state):
settings = ExLlamaV2Sampler.Settings()
settings.temperature = state['temperature']
settings.top_k = state['top_k']
settings.top_p = state['top_p']
settings.token_repetition_penalty = state['repetition_penalty']
settings.token_repetition_range = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']
if state['ban_eos_token']:
settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
ids = self.tokenizer.encode(prompt)
ids = ids[:, -get_max_prompt_length(state):]
initial_len = ids.shape[-1]
if state['auto_max_new_tokens']:
max_new_tokens = state['truncation_length'] - ids.shape[-1]
else:
max_new_tokens = state['max_new_tokens']
# _gen_begin_base
self.cache.current_seq_len = 0
self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
has_leading_space = False
for i in range(max_new_tokens):
logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None).float().cpu()
token, _ = ExLlamaV2Sampler.sample(logits, settings, ids, random.random())
ids = torch.cat([ids, token], dim=1)
if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith(''):
has_leading_space = True
decoded_text = self.tokenizer.decode(ids[:, initial_len:])[0]
if has_leading_space:
decoded_text = ' ' + decoded_text
yield decoded_text
if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
break
def generate(self, prompt, state):
output = ''
for output in self.generate_with_streaming(prompt, state):
pass
return output
def encode(self, string, **kwargs):
return self.tokenizer.encode(string)
def decode(self, string, **kwargs):
return self.tokenizer.decode(string)[0]

119
modules/exllamav2_hf.py Normal file
View File

@ -0,0 +1,119 @@
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union
import torch
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from modules import shared
from modules.logging_colors import logger
class Exllamav2HF(PreTrainedModel):
def __init__(self, config: ExLlamaV2Config):
super().__init__(PretrainedConfig())
self.ex_config = config
self.ex_model = ExLlamaV2(config)
split = None
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
self.ex_model.load(split)
self.generation_config = GenerationConfig()
self.ex_cache = ExLlamaV2Cache(self.ex_model)
self.past_seq = None
if shared.args.cfg_cache:
self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
self.past_seq_negative = None
def _validate_model_class(self):
pass
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
pass
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {'input_ids': input_ids, **kwargs}
@property
def device(self) -> torch.device:
return torch.device(0)
def __call__(self, *args, **kwargs):
use_cache = kwargs.get('use_cache', True)
labels = kwargs.get('labels', None)
past_key_values = kwargs.get('past_key_values', None)
if len(args) > 0:
if not shared.args.cfg_cache:
logger.error("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.")
return
input_ids = args[0]
is_negative = True
past_seq = self.past_seq_negative
ex_cache = self.ex_cache_negative
else:
input_ids = kwargs['input_ids']
is_negative = False
past_seq = self.past_seq
ex_cache = self.ex_cache
seq = input_ids[0].tolist()
if is_negative and past_key_values is not None:
seq = past_key_values + seq
seq_tensor = torch.tensor(seq)
# Make the forward call
if labels is None:
if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
ex_cache.current_seq_len = 0
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True)
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache).to(input_ids.device)
else:
ex_cache.current_seq_len = 0
# logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False)
logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache)
if is_negative:
self.past_seq_negative = seq_tensor
else:
self.past_seq = seq_tensor
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, logits.shape[-1])
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
config = ExLlamaV2Config()
config.model_dir = pretrained_model_name_or_path
config.prepare()
config.max_seq_len = shared.args.max_seq_len
return Exllamav2HF(config)

View File

@ -42,6 +42,15 @@ loaders_and_params = OrderedDict({
'compress_pos_emb', 'compress_pos_emb',
'exllama_info', 'exllama_info',
], ],
'ExLlamav2': [
'gpu_split',
'max_seq_len',
],
'ExLlamav2_HF': [
'gpu_split',
'max_seq_len',
'cfg_cache',
],
'AutoGPTQ': [ 'AutoGPTQ': [
'triton', 'triton',
'no_inject_fused_attention', 'no_inject_fused_attention',
@ -180,6 +189,42 @@ loaders_samplers = {
'ban_eos_token', 'ban_eos_token',
'auto_max_new_tokens', 'auto_max_new_tokens',
}, },
'ExLlamav2': {
'temperature',
'top_p',
'top_k',
'repetition_penalty',
'repetition_penalty_range',
'seed',
'ban_eos_token',
'auto_max_new_tokens',
},
'ExLlamav2_HF': {
'temperature',
'top_p',
'top_k',
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'tfs',
'top_a',
'repetition_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'min_length',
'seed',
'do_sample',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
},
'AutoGPTQ': { 'AutoGPTQ': {
'temperature', 'temperature',
'top_p', 'top_p',

View File

@ -59,6 +59,8 @@ def load_model(model_name, loader=None):
'RWKV': RWKV_loader, 'RWKV': RWKV_loader,
'ExLlama': ExLlama_loader, 'ExLlama': ExLlama_loader,
'ExLlama_HF': ExLlama_HF_loader, 'ExLlama_HF': ExLlama_HF_loader,
'ExLlamav2': ExLlamav2_loader,
'ExLlamav2_HF': ExLlamav2_HF_loader,
'ctransformers': ctransformers_loader, 'ctransformers': ctransformers_loader,
} }
@ -329,6 +331,19 @@ def ExLlama_HF_loader(model_name):
return ExllamaHF.from_pretrained(model_name) return ExllamaHF.from_pretrained(model_name)
def ExLlamav2_loader(model_name):
from modules.exllamav2 import Exllamav2Model
model, tokenizer = Exllamav2Model.from_pretrained(model_name)
return model, tokenizer
def ExLlamav2_HF_loader(model_name):
from modules.exllamav2_hf import Exllamav2HF
return Exllamav2HF.from_pretrained(model_name)
def get_max_memory_dict(): def get_max_memory_dict():
max_memory = {} max_memory = {}
if shared.args.gpu_memory: if shared.args.gpu_memory:

View File

@ -219,6 +219,10 @@ def fix_loader_name(name):
return 'ExLlama' return 'ExLlama'
elif name in ['exllama-hf', 'exllama_hf', 'exllama hf', 'ex-llama-hf', 'ex_llama_hf']: elif name in ['exllama-hf', 'exllama_hf', 'exllama hf', 'ex-llama-hf', 'ex_llama_hf']:
return 'ExLlama_HF' return 'ExLlama_HF'
elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2']:
return 'ExLlamav2'
elif name in ['exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf']:
return 'ExLlamav2_HF'
elif name in ['ctransformers', 'ctranforemrs', 'ctransformer']: elif name in ['ctransformers', 'ctranforemrs', 'ctransformer']:
return 'ctransformers' return 'ctransformers'

View File

@ -42,7 +42,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
yield '' yield ''
return return
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'CtransformersModel']: if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel']:
generate_func = generate_reply_custom generate_func = generate_reply_custom
else: else:
generate_func = generate_reply_HF generate_func = generate_reply_HF
@ -106,9 +106,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None): def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel']: if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel', 'Exllamav2Model']:
input_ids = shared.tokenizer.encode(str(prompt)) input_ids = shared.tokenizer.encode(str(prompt))
input_ids = np.array(input_ids).reshape(1, len(input_ids)) if shared.model.__class__.__name__ not in ['Exllamav2Model']:
input_ids = np.array(input_ids).reshape(1, len(input_ids))
else: else:
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens) input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
@ -120,7 +121,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if truncation_length is not None: if truncation_length is not None:
input_ids = input_ids[:, -truncation_length:] input_ids = input_ids[:, -truncation_length:]
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'CtransformersModel'] or shared.args.cpu: if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel'] or shared.args.cpu:
return input_ids return input_ids
elif shared.args.deepspeed: elif shared.args.deepspeed:
return input_ids.to(device=local_rank) return input_ids.to(device=local_rank)

View File

@ -8,7 +8,9 @@ accelerate==0.22.*
colorama colorama
datasets datasets
einops einops
exllamav2==0.0.0
markdown markdown
ninja
numpy==1.24 numpy==1.24
optimum==1.12.0 optimum==1.12.0
pandas pandas

View File

@ -8,7 +8,9 @@ accelerate==0.22.*
colorama colorama
datasets datasets
einops einops
exllamav2==0.0.0
markdown markdown
ninja
numpy==1.24 numpy==1.24
optimum==1.12.0 optimum==1.12.0
pandas pandas