mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Revert "Remove non-HF ExLlamaV2 loader (#5431)"
This reverts commit cde000d478
.
This commit is contained in:
parent
8c35fefb3b
commit
2a1063eff5
@ -12,7 +12,7 @@ from modules.models import reload_model
|
|||||||
def add_lora_to_model(lora_names):
|
def add_lora_to_model(lora_names):
|
||||||
if 'GPTQForCausalLM' in shared.model.__class__.__name__ or shared.args.loader == 'AutoGPTQ':
|
if 'GPTQForCausalLM' in shared.model.__class__.__name__ or shared.args.loader == 'AutoGPTQ':
|
||||||
add_lora_autogptq(lora_names)
|
add_lora_autogptq(lora_names)
|
||||||
elif shared.model.__class__.__name__ == 'Exllamav2HF' or shared.args.loader == 'ExLlamav2_HF':
|
elif shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav2HF'] or shared.args.loader in ['ExLlamav2', 'ExLlamav2_HF']:
|
||||||
add_lora_exllamav2(lora_names)
|
add_lora_exllamav2(lora_names)
|
||||||
else:
|
else:
|
||||||
add_lora_transformers(lora_names)
|
add_lora_transformers(lora_names)
|
||||||
@ -39,7 +39,11 @@ def add_lora_exllamav2(lora_names):
|
|||||||
shared.model.loras = []
|
shared.model.loras = []
|
||||||
for lora_name in lora_names:
|
for lora_name in lora_names:
|
||||||
lora_path = get_lora_path(lora_name)
|
lora_path = get_lora_path(lora_name)
|
||||||
|
if shared.model.__class__.__name__ == 'Exllamav2Model':
|
||||||
|
lora = ExLlamaV2Lora.from_directory(shared.model.model, str(lora_path))
|
||||||
|
else:
|
||||||
lora = ExLlamaV2Lora.from_directory(shared.model.ex_model, str(lora_path))
|
lora = ExLlamaV2Lora.from_directory(shared.model.ex_model, str(lora_path))
|
||||||
|
|
||||||
shared.model.loras.append(lora)
|
shared.model.loras.append(lora)
|
||||||
|
|
||||||
shared.lora_names = lora_names
|
shared.lora_names = lora_names
|
||||||
|
149
modules/exllamav2.py
Normal file
149
modules/exllamav2.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
import traceback
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from exllamav2 import (
|
||||||
|
ExLlamaV2,
|
||||||
|
ExLlamaV2Cache,
|
||||||
|
ExLlamaV2Cache_8bit,
|
||||||
|
ExLlamaV2Config,
|
||||||
|
ExLlamaV2Tokenizer
|
||||||
|
)
|
||||||
|
from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
from modules.text_generation import get_max_prompt_length
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flash_attn
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
logger.warning(
|
||||||
|
'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
|
||||||
|
'to be a lot higher than it could be.\n'
|
||||||
|
'Try installing flash-attention following the instructions here: '
|
||||||
|
'https://github.com/Dao-AILab/flash-attention#installation-and-features'
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
logger.warning('Failed to load flash-attention due to the following error:\n')
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
class Exllamav2Model:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, path_to_model):
|
||||||
|
|
||||||
|
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
|
||||||
|
|
||||||
|
config = ExLlamaV2Config()
|
||||||
|
config.model_dir = str(path_to_model)
|
||||||
|
config.prepare()
|
||||||
|
|
||||||
|
config.max_seq_len = shared.args.max_seq_len
|
||||||
|
config.scale_pos_emb = shared.args.compress_pos_emb
|
||||||
|
config.scale_alpha_value = shared.args.alpha_value
|
||||||
|
config.no_flash_attn = shared.args.no_flash_attn
|
||||||
|
config.num_experts_per_token = int(shared.args.num_experts_per_token)
|
||||||
|
|
||||||
|
model = ExLlamaV2(config)
|
||||||
|
|
||||||
|
split = None
|
||||||
|
if shared.args.gpu_split:
|
||||||
|
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
||||||
|
|
||||||
|
model.load(split)
|
||||||
|
|
||||||
|
tokenizer = ExLlamaV2Tokenizer(config)
|
||||||
|
if shared.args.cache_8bit:
|
||||||
|
cache = ExLlamaV2Cache_8bit(model)
|
||||||
|
else:
|
||||||
|
cache = ExLlamaV2Cache(model)
|
||||||
|
|
||||||
|
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
|
||||||
|
|
||||||
|
result = self()
|
||||||
|
result.model = model
|
||||||
|
result.cache = cache
|
||||||
|
result.tokenizer = tokenizer
|
||||||
|
result.generator = generator
|
||||||
|
result.loras = None
|
||||||
|
return result, result
|
||||||
|
|
||||||
|
def encode(self, string, **kwargs):
|
||||||
|
return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True)
|
||||||
|
|
||||||
|
def decode(self, ids, **kwargs):
|
||||||
|
if isinstance(ids, list):
|
||||||
|
ids = torch.tensor([ids])
|
||||||
|
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||||
|
ids = ids.view(1, -1)
|
||||||
|
|
||||||
|
return self.tokenizer.decode(ids, decode_special_tokens=True)[0]
|
||||||
|
|
||||||
|
def get_logits(self, token_ids, **kwargs):
|
||||||
|
self.cache.current_seq_len = 0
|
||||||
|
if token_ids.shape[-1] > 1:
|
||||||
|
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
|
||||||
|
|
||||||
|
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()
|
||||||
|
|
||||||
|
def generate_with_streaming(self, prompt, state):
|
||||||
|
settings = ExLlamaV2Sampler.Settings()
|
||||||
|
|
||||||
|
settings.token_repetition_penalty = state['repetition_penalty']
|
||||||
|
settings.token_repetition_range = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']
|
||||||
|
|
||||||
|
settings.token_frequency_penalty = state['frequency_penalty']
|
||||||
|
settings.token_presence_penalty = state['presence_penalty']
|
||||||
|
|
||||||
|
settings.temperature = state['temperature']
|
||||||
|
settings.top_k = state['top_k']
|
||||||
|
settings.top_p = state['top_p']
|
||||||
|
settings.top_a = state['top_a']
|
||||||
|
settings.min_p = state['min_p']
|
||||||
|
settings.tfs = state['tfs']
|
||||||
|
settings.typical = state['typical_p']
|
||||||
|
|
||||||
|
settings.temperature_last = state['temperature_last']
|
||||||
|
|
||||||
|
settings.mirostat = state['mirostat_mode'] == 2
|
||||||
|
settings.mirostat_tau = state['mirostat_tau']
|
||||||
|
settings.mirostat_eta = state['mirostat_eta']
|
||||||
|
|
||||||
|
if state['ban_eos_token']:
|
||||||
|
settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
|
||||||
|
|
||||||
|
if state['custom_token_bans']:
|
||||||
|
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
|
||||||
|
if len(to_ban) > 0:
|
||||||
|
settings.disallow_tokens(self.tokenizer, to_ban)
|
||||||
|
|
||||||
|
ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
|
||||||
|
ids = ids[:, -get_max_prompt_length(state):]
|
||||||
|
|
||||||
|
if state['auto_max_new_tokens']:
|
||||||
|
max_new_tokens = state['truncation_length'] - ids.shape[-1]
|
||||||
|
else:
|
||||||
|
max_new_tokens = state['max_new_tokens']
|
||||||
|
|
||||||
|
self.generator.begin_stream(ids, settings, loras=self.loras)
|
||||||
|
|
||||||
|
decoded_text = ''
|
||||||
|
for i in range(max_new_tokens):
|
||||||
|
chunk, eos, _ = self.generator.stream()
|
||||||
|
if eos or shared.stop_everything:
|
||||||
|
break
|
||||||
|
|
||||||
|
decoded_text += chunk
|
||||||
|
yield decoded_text
|
||||||
|
|
||||||
|
def generate(self, prompt, state):
|
||||||
|
output = ''
|
||||||
|
for output in self.generate_with_streaming(prompt, state):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return output
|
@ -83,6 +83,16 @@ loaders_and_params = OrderedDict({
|
|||||||
'trust_remote_code',
|
'trust_remote_code',
|
||||||
'no_use_fast',
|
'no_use_fast',
|
||||||
],
|
],
|
||||||
|
'ExLlamav2': [
|
||||||
|
'gpu_split',
|
||||||
|
'max_seq_len',
|
||||||
|
'no_flash_attn',
|
||||||
|
'num_experts_per_token',
|
||||||
|
'cache_8bit',
|
||||||
|
'alpha_value',
|
||||||
|
'compress_pos_emb',
|
||||||
|
'exllamav2_info',
|
||||||
|
],
|
||||||
'AutoGPTQ': [
|
'AutoGPTQ': [
|
||||||
'triton',
|
'triton',
|
||||||
'no_inject_fused_attention',
|
'no_inject_fused_attention',
|
||||||
@ -197,6 +207,29 @@ loaders_samplers = {
|
|||||||
'AutoAWQ': transformers_samplers(),
|
'AutoAWQ': transformers_samplers(),
|
||||||
'QuIP#': transformers_samplers(),
|
'QuIP#': transformers_samplers(),
|
||||||
'HQQ': transformers_samplers(),
|
'HQQ': transformers_samplers(),
|
||||||
|
'ExLlamav2': {
|
||||||
|
'temperature',
|
||||||
|
'temperature_last',
|
||||||
|
'top_p',
|
||||||
|
'min_p',
|
||||||
|
'top_k',
|
||||||
|
'typical_p',
|
||||||
|
'tfs',
|
||||||
|
'top_a',
|
||||||
|
'repetition_penalty',
|
||||||
|
'presence_penalty',
|
||||||
|
'frequency_penalty',
|
||||||
|
'repetition_penalty_range',
|
||||||
|
'seed',
|
||||||
|
'mirostat_mode',
|
||||||
|
'mirostat_tau',
|
||||||
|
'mirostat_eta',
|
||||||
|
'ban_eos_token',
|
||||||
|
'add_bos_token',
|
||||||
|
'custom_token_bans',
|
||||||
|
'skip_special_tokens',
|
||||||
|
'auto_max_new_tokens',
|
||||||
|
},
|
||||||
'ExLlamav2_HF': {
|
'ExLlamav2_HF': {
|
||||||
'temperature',
|
'temperature',
|
||||||
'temperature_last',
|
'temperature_last',
|
||||||
|
@ -13,10 +13,11 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
|
|||||||
logger.error("No model is loaded! Select one in the Model tab.")
|
logger.error("No model is loaded! Select one in the Model tab.")
|
||||||
return 'Error: No model is loaded1 Select one in the Model tab.', previous
|
return 'Error: No model is loaded1 Select one in the Model tab.', previous
|
||||||
|
|
||||||
|
is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
|
||||||
is_non_hf_llamacpp = shared.model.__class__.__name__ == 'LlamaCppModel'
|
is_non_hf_llamacpp = shared.model.__class__.__name__ == 'LlamaCppModel'
|
||||||
|
|
||||||
if use_samplers:
|
if use_samplers:
|
||||||
if is_non_hf_llamacpp:
|
if any([is_non_hf_exllamav2, is_non_hf_llamacpp]):
|
||||||
logger.error("Sampler hijacking is not supported non-Huggingface loaders.")
|
logger.error("Sampler hijacking is not supported non-Huggingface loaders.")
|
||||||
# sampling is all done in c for exllama, so it is really hard to hijack
|
# sampling is all done in c for exllama, so it is really hard to hijack
|
||||||
# it should be possible to hijack llamacpp sampler by hijacking all their sampling methods,
|
# it should be possible to hijack llamacpp sampler by hijacking all their sampling methods,
|
||||||
@ -30,7 +31,13 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
|
|||||||
|
|
||||||
scores = sampler_hijack.global_scores[-1]
|
scores = sampler_hijack.global_scores[-1]
|
||||||
else:
|
else:
|
||||||
if is_non_hf_llamacpp:
|
if is_non_hf_exllamav2:
|
||||||
|
if is_torch_xpu_available():
|
||||||
|
tokens = shared.tokenizer.encode(prompt).to("xpu:0")
|
||||||
|
else:
|
||||||
|
tokens = shared.tokenizer.encode(prompt).cuda()
|
||||||
|
scores = shared.model.get_logits(tokens)[-1][-1]
|
||||||
|
elif is_non_hf_llamacpp:
|
||||||
tokens = shared.tokenizer.encode(prompt)
|
tokens = shared.tokenizer.encode(prompt)
|
||||||
scores = shared.model.get_logits(tokens)[-1][-1]
|
scores = shared.model.get_logits(tokens)[-1][-1]
|
||||||
else:
|
else:
|
||||||
@ -38,7 +45,6 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
|
|||||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
|
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
|
||||||
else:
|
else:
|
||||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
|
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
|
||||||
|
|
||||||
output = shared.model(input_ids=tokens)
|
output = shared.model(input_ids=tokens)
|
||||||
scores = output['logits'][-1][-1]
|
scores = output['logits'][-1][-1]
|
||||||
|
|
||||||
|
@ -65,6 +65,7 @@ def load_model(model_name, loader=None):
|
|||||||
'GPTQ-for-LLaMa': GPTQ_loader,
|
'GPTQ-for-LLaMa': GPTQ_loader,
|
||||||
'llama.cpp': llamacpp_loader,
|
'llama.cpp': llamacpp_loader,
|
||||||
'llamacpp_HF': llamacpp_HF_loader,
|
'llamacpp_HF': llamacpp_HF_loader,
|
||||||
|
'ExLlamav2': ExLlamav2_loader,
|
||||||
'ExLlamav2_HF': ExLlamav2_HF_loader,
|
'ExLlamav2_HF': ExLlamav2_HF_loader,
|
||||||
'ctransformers': ctransformers_loader,
|
'ctransformers': ctransformers_loader,
|
||||||
'AutoAWQ': AutoAWQ_loader,
|
'AutoAWQ': AutoAWQ_loader,
|
||||||
@ -375,6 +376,13 @@ def AutoGPTQ_loader(model_name):
|
|||||||
return modules.AutoGPTQ_loader.load_quantized(model_name)
|
return modules.AutoGPTQ_loader.load_quantized(model_name)
|
||||||
|
|
||||||
|
|
||||||
|
def ExLlamav2_loader(model_name):
|
||||||
|
from modules.exllamav2 import Exllamav2Model
|
||||||
|
|
||||||
|
model, tokenizer = Exllamav2Model.from_pretrained(model_name)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
def ExLlamav2_HF_loader(model_name):
|
def ExLlamav2_HF_loader(model_name):
|
||||||
from modules.exllamav2_hf import Exllamav2HF
|
from modules.exllamav2_hf import Exllamav2HF
|
||||||
|
|
||||||
|
@ -141,8 +141,6 @@ def get_model_metadata(model):
|
|||||||
if re.match(pat.lower(), model.lower()):
|
if re.match(pat.lower(), model.lower()):
|
||||||
for k in settings[pat]:
|
for k in settings[pat]:
|
||||||
model_settings[k] = settings[pat][k]
|
model_settings[k] = settings[pat][k]
|
||||||
if k == 'loader' and settings[pat][k] == 'ExLlamav2':
|
|
||||||
model_settings[k] = 'ExLlamav2_HF'
|
|
||||||
|
|
||||||
return model_settings
|
return model_settings
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ group.add_argument('--chat-buttons', action='store_true', help='Show buttons on
|
|||||||
|
|
||||||
# Model loader
|
# Model loader
|
||||||
group = parser.add_argument_group('Model loader')
|
group = parser.add_argument_group('Model loader')
|
||||||
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, ctransformers, QuIP#.')
|
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, ExLlamav2, AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, ctransformers, QuIP#.')
|
||||||
|
|
||||||
# Transformers/Accelerate
|
# Transformers/Accelerate
|
||||||
group = parser.add_argument_group('Transformers/Accelerate')
|
group = parser.add_argument_group('Transformers/Accelerate')
|
||||||
@ -132,11 +132,11 @@ group.add_argument('--no_offload_kqv', action='store_true', help='Do not offload
|
|||||||
group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (llama-cpp-python). Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.')
|
group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (llama-cpp-python). Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.')
|
||||||
group.add_argument('--row_split', action='store_true', help='Split multi-gpu by row instead of layer. Faster on some cards.')
|
group.add_argument('--row_split', action='store_true', help='Split multi-gpu by row instead of layer. Faster on some cards.')
|
||||||
|
|
||||||
# ExLlamaV2
|
# ExLlama
|
||||||
group = parser.add_argument_group('ExLlamaV2')
|
group = parser.add_argument_group('ExLlama')
|
||||||
group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
|
group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
|
||||||
group.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
|
group.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
|
||||||
group.add_argument('--cfg-cache', action='store_true', help='Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
|
group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
|
||||||
group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
|
group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
|
||||||
group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
|
group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
|
||||||
group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
|
group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
|
||||||
@ -250,7 +250,11 @@ def fix_loader_name(name):
|
|||||||
return 'AutoGPTQ'
|
return 'AutoGPTQ'
|
||||||
elif name in ['gptq-for-llama', 'gptqforllama', 'gptqllama', 'gptq for llama', 'gptq_for_llama']:
|
elif name in ['gptq-for-llama', 'gptqforllama', 'gptqllama', 'gptq for llama', 'gptq_for_llama']:
|
||||||
return 'GPTQ-for-LLaMa'
|
return 'GPTQ-for-LLaMa'
|
||||||
elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2', 'exllama', 'ex-llama', 'ex_llama', 'exlama', 'exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf', 'exllama2-hf', 'exllama2_hf', 'exllama-2-hf', 'exllama_2_hf', 'exllama-2_hf']:
|
elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']:
|
||||||
|
return 'ExLlama'
|
||||||
|
elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2']:
|
||||||
|
return 'ExLlamav2'
|
||||||
|
elif name in ['exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf', 'exllama2-hf', 'exllama2_hf', 'exllama-2-hf', 'exllama_2_hf', 'exllama-2_hf']:
|
||||||
return 'ExLlamav2_HF'
|
return 'ExLlamav2_HF'
|
||||||
elif name in ['ctransformers', 'ctranforemrs', 'ctransformer']:
|
elif name in ['ctransformers', 'ctranforemrs', 'ctransformer']:
|
||||||
return 'ctransformers'
|
return 'ctransformers'
|
||||||
|
@ -45,7 +45,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
|||||||
yield ''
|
yield ''
|
||||||
return
|
return
|
||||||
|
|
||||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'CtransformersModel']:
|
if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'CtransformersModel']:
|
||||||
generate_func = generate_reply_custom
|
generate_func = generate_reply_custom
|
||||||
else:
|
else:
|
||||||
generate_func = generate_reply_HF
|
generate_func = generate_reply_HF
|
||||||
@ -121,10 +121,9 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||||||
if shared.tokenizer is None:
|
if shared.tokenizer is None:
|
||||||
raise ValueError('No tokenizer is loaded')
|
raise ValueError('No tokenizer is loaded')
|
||||||
|
|
||||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'CtransformersModel']:
|
if shared.model.__class__.__name__ in ['LlamaCppModel', 'CtransformersModel', 'Exllamav2Model']:
|
||||||
input_ids = shared.tokenizer.encode(str(prompt))
|
input_ids = shared.tokenizer.encode(str(prompt))
|
||||||
# The step below is necessary for llama.cpp, but may not be
|
if shared.model.__class__.__name__ not in ['Exllamav2Model']:
|
||||||
# necessary for future loaders.
|
|
||||||
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
||||||
else:
|
else:
|
||||||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
|
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
|
||||||
@ -136,7 +135,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||||||
if truncation_length is not None:
|
if truncation_length is not None:
|
||||||
input_ids = input_ids[:, -truncation_length:]
|
input_ids = input_ids[:, -truncation_length:]
|
||||||
|
|
||||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'CtransformersModel'] or shared.args.cpu:
|
if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'CtransformersModel'] or shared.args.cpu:
|
||||||
return input_ids
|
return input_ids
|
||||||
elif shared.args.deepspeed:
|
elif shared.args.deepspeed:
|
||||||
return input_ids.to(device=local_rank)
|
return input_ids.to(device=local_rank)
|
||||||
|
@ -142,6 +142,7 @@ def create_ui():
|
|||||||
shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel for GPTQ models.')
|
shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel for GPTQ models.')
|
||||||
shared.gradio['disable_exllamav2'] = gr.Checkbox(label="disable_exllamav2", value=shared.args.disable_exllamav2, info='Disable ExLlamav2 kernel for GPTQ models.')
|
shared.gradio['disable_exllamav2'] = gr.Checkbox(label="disable_exllamav2", value=shared.args.disable_exllamav2, info='Disable ExLlamav2 kernel for GPTQ models.')
|
||||||
shared.gradio['gptq_for_llama_info'] = gr.Markdown('Legacy loader for compatibility with older GPUs. ExLlamav2_HF or AutoGPTQ are preferred for GPTQ models when supported.')
|
shared.gradio['gptq_for_llama_info'] = gr.Markdown('Legacy loader for compatibility with older GPUs. ExLlamav2_HF or AutoGPTQ are preferred for GPTQ models when supported.')
|
||||||
|
shared.gradio['exllamav2_info'] = gr.Markdown("ExLlamav2_HF is recommended over ExLlamav2 for better integration with extensions and more consistent sampling behavior across loaders.")
|
||||||
shared.gradio['llamacpp_HF_info'] = gr.Markdown('llamacpp_HF loads llama.cpp as a Transformers model. To use it, you need to download a tokenizer.\n\nOption 1 (recommended): place your .gguf in a subfolder of models/ along with these 4 files: special_tokens_map.json, tokenizer_config.json, tokenizer.json, tokenizer.model.\n\nOption 2: download `oobabooga/llama-tokenizer` under "Download model or LoRA". That\'s a default Llama tokenizer that will work for some (but not all) models.')
|
shared.gradio['llamacpp_HF_info'] = gr.Markdown('llamacpp_HF loads llama.cpp as a Transformers model. To use it, you need to download a tokenizer.\n\nOption 1 (recommended): place your .gguf in a subfolder of models/ along with these 4 files: special_tokens_map.json, tokenizer_config.json, tokenizer.json, tokenizer.model.\n\nOption 2: download `oobabooga/llama-tokenizer` under "Download model or LoRA". That\'s a default Llama tokenizer that will work for some (but not all) models.')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
Loading…
Reference in New Issue
Block a user