mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-25 22:08:53 +01:00
parent
6430acadde
commit
98361af4d5
@ -143,6 +143,11 @@ loaders_and_params = OrderedDict({
|
||||
'no_mmap',
|
||||
'mlock'
|
||||
],
|
||||
'QuIP#': [
|
||||
'trust_remote_code',
|
||||
'no_use_fast',
|
||||
'no_flash_attn',
|
||||
]
|
||||
})
|
||||
|
||||
loaders_samplers = {
|
||||
@ -453,6 +458,43 @@ loaders_samplers = {
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'QuIP#': {
|
||||
'temperature',
|
||||
'temperature_last',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
'typical_p',
|
||||
'epsilon_cutoff',
|
||||
'eta_cutoff',
|
||||
'tfs',
|
||||
'top_a',
|
||||
'repetition_penalty',
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'repetition_penalty_range',
|
||||
'encoder_repetition_penalty',
|
||||
'no_repeat_ngram_size',
|
||||
'min_length',
|
||||
'seed',
|
||||
'do_sample',
|
||||
'penalty_alpha',
|
||||
'num_beams',
|
||||
'length_penalty',
|
||||
'early_stopping',
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'grammar_file_row',
|
||||
'grammar_string',
|
||||
'guidance_scale',
|
||||
'negative_prompt',
|
||||
'ban_eos_token',
|
||||
'custom_token_bans',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
}
|
||||
|
||||
loaders_model_types = {
|
||||
|
@ -1,4 +1,5 @@
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
@ -23,6 +24,7 @@ import modules.shared as shared
|
||||
from modules import RoPE, llama_attn_hijack, sampler_hijack
|
||||
from modules.logging_colors import logger
|
||||
from modules.models_settings import get_model_metadata
|
||||
from modules.relative_imports import RelativeImport
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
@ -69,6 +71,7 @@ def load_model(model_name, loader=None):
|
||||
'ExLlamav2_HF': ExLlamav2_HF_loader,
|
||||
'ctransformers': ctransformers_loader,
|
||||
'AutoAWQ': AutoAWQ_loader,
|
||||
'QuIP#': QuipSharp_loader,
|
||||
}
|
||||
|
||||
metadata = get_model_metadata(model_name)
|
||||
@ -321,6 +324,37 @@ def AutoAWQ_loader(model_name):
|
||||
return model
|
||||
|
||||
|
||||
def QuipSharp_loader(model_name):
|
||||
try:
|
||||
with RelativeImport("repositories/quip-sharp"):
|
||||
from lib.utils.unsafe_import import model_from_hf_path
|
||||
except:
|
||||
logger.error(
|
||||
"\nQuIP# has not been found. It must be installed manually for now.\n"
|
||||
"For instructions on how to do that, please consult:\n"
|
||||
"https://github.com/oobabooga/text-generation-webui/pull/4803\n"
|
||||
)
|
||||
return None, None
|
||||
|
||||
# This fixes duplicate logging messages after the import above.
|
||||
handlers = logging.getLogger().handlers
|
||||
if len(handlers) > 1:
|
||||
logging.getLogger().removeHandler(handlers[1])
|
||||
|
||||
model_dir = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if not all((model_dir / file).exists() for file in ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']):
|
||||
logger.error(f"Could not load the model because the tokenizer files could not be found in the model folder. Please download the following files from the original (unquantized) model into {model_dir}: special_tokens_map.json, tokenizer.json, tokenizer.model, tokenizer_config.json.")
|
||||
return None, None
|
||||
|
||||
model, model_str = model_from_hf_path(
|
||||
model_dir,
|
||||
use_cuda_graph=False,
|
||||
use_flash_attn=not shared.args.no_flash_attn
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def GPTQ_loader(model_name):
|
||||
|
||||
# Monkey patch
|
||||
|
@ -33,14 +33,24 @@ def get_model_metadata(model):
|
||||
for k in settings[pat]:
|
||||
model_settings[k] = settings[pat][k]
|
||||
|
||||
|
||||
path = Path(f'{shared.args.model_dir}/{model}/config.json')
|
||||
if path.exists():
|
||||
hf_metadata = json.loads(open(path, 'r').read())
|
||||
else:
|
||||
hf_metadata = None
|
||||
|
||||
if 'loader' not in model_settings:
|
||||
loader = infer_loader(model, model_settings)
|
||||
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
|
||||
loader = 'AutoGPTQ'
|
||||
if hf_metadata is not None and 'quip_params' in hf_metadata:
|
||||
model_settings['loader'] = 'QuIP#'
|
||||
else:
|
||||
loader = infer_loader(model, model_settings)
|
||||
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
|
||||
loader = 'AutoGPTQ'
|
||||
|
||||
model_settings['loader'] = loader
|
||||
model_settings['loader'] = loader
|
||||
|
||||
# Read GGUF metadata
|
||||
# GGUF metadata
|
||||
if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
|
||||
path = Path(f'{shared.args.model_dir}/{model}')
|
||||
if path.is_file():
|
||||
@ -57,9 +67,8 @@ def get_model_metadata(model):
|
||||
model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
|
||||
|
||||
else:
|
||||
# Read transformers metadata
|
||||
path = Path(f'{shared.args.model_dir}/{model}/config.json')
|
||||
if path.exists():
|
||||
# Transformers metadata
|
||||
if hf_metadata is not None:
|
||||
metadata = json.loads(open(path, 'r').read())
|
||||
if 'max_position_embeddings' in metadata:
|
||||
model_settings['truncation_length'] = metadata['max_position_embeddings']
|
||||
|
@ -241,6 +241,8 @@ def fix_loader_name(name):
|
||||
return 'ctransformers'
|
||||
elif name in ['autoawq', 'awq', 'auto-awq']:
|
||||
return 'AutoAWQ'
|
||||
elif name in ['quip#', 'quip-sharp', 'quipsharp', 'quip_sharp']:
|
||||
return 'QuIP#'
|
||||
|
||||
|
||||
def add_extension(name, last=False):
|
||||
|
Loading…
Reference in New Issue
Block a user