mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-21 23:57:58 +01:00
Read GGUF metadata (#3873)
This commit is contained in:
parent
39f4800d94
commit
9331ab4798
@ -7,10 +7,7 @@ from modules import shared
|
||||
from modules.chat import generate_chat_reply
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import (
|
||||
get_model_settings_from_yamls,
|
||||
update_model_parameters
|
||||
)
|
||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||
from modules.text_generation import (
|
||||
encode,
|
||||
generate_reply,
|
||||
@ -132,7 +129,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
shared.model_name = model_name
|
||||
unload_model()
|
||||
|
||||
model_settings = get_model_settings_from_yamls(shared.model_name)
|
||||
model_settings = get_model_metadata(shared.model_name)
|
||||
shared.settings.update(model_settings)
|
||||
update_model_parameters(model_settings, initial=True)
|
||||
|
||||
|
@ -1,11 +1,9 @@
|
||||
from modules import shared
|
||||
from modules.utils import get_available_models
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import (get_model_settings_from_yamls,
|
||||
update_model_parameters)
|
||||
|
||||
from extensions.openai.embeddings import get_embeddings_model_name
|
||||
from extensions.openai.errors import *
|
||||
from modules import shared
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||
from modules.utils import get_available_models
|
||||
|
||||
|
||||
def get_current_model_list() -> list:
|
||||
@ -33,7 +31,7 @@ def load_model(model_name: str) -> dict:
|
||||
shared.model_name = model_name
|
||||
unload_model()
|
||||
|
||||
model_settings = get_model_settings_from_yamls(shared.model_name)
|
||||
model_settings = get_model_metadata(shared.model_name)
|
||||
shared.settings.update(model_settings)
|
||||
update_model_parameters(model_settings, initial=True)
|
||||
|
||||
|
@ -8,10 +8,7 @@ from tqdm import tqdm
|
||||
|
||||
from modules import shared
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import (
|
||||
get_model_settings_from_yamls,
|
||||
update_model_parameters
|
||||
)
|
||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||
from modules.text_generation import encode
|
||||
|
||||
|
||||
@ -69,7 +66,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
|
||||
if model != 'current model':
|
||||
try:
|
||||
yield cumulative_log + f"Loading {model}...\n\n"
|
||||
model_settings = get_model_settings_from_yamls(model)
|
||||
model_settings = get_model_metadata(model)
|
||||
shared.settings.update(model_settings) # hijacking the interface defaults
|
||||
update_model_parameters(model_settings) # hijacking the command-line arguments
|
||||
shared.model_name = model
|
||||
|
84
modules/metadata_gguf.py
Normal file
84
modules/metadata_gguf.py
Normal file
@ -0,0 +1,84 @@
|
||||
import struct
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
class GGUFValueType(IntEnum):
    """Type tags that identify how a GGUF metadata value is encoded.

    The integer values are the tags as stored in the file, so a member can be
    built directly from the ``uint32`` that is read ahead of each value.
    """

    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12
|
||||
|
||||
|
||||
# ``struct`` format string for every fixed-size scalar type. STRING and ARRAY
# are deliberately absent: they are variable-length and handled separately.
_simple_value_packing = {
    GGUFValueType.UINT8: "<B",
    GGUFValueType.INT8: "<b",
    GGUFValueType.UINT16: "<H",
    GGUFValueType.INT16: "<h",
    GGUFValueType.UINT32: "<I",
    GGUFValueType.INT32: "<i",
    GGUFValueType.FLOAT32: "<f",
    GGUFValueType.UINT64: "<Q",
    GGUFValueType.INT64: "<q",
    GGUFValueType.FLOAT64: "<d",
    GGUFValueType.BOOL: "?",
}

# Number of bytes each fixed-size scalar occupies in the file. Derived from
# the format strings above so the two tables can never disagree.
value_type_info = {
    value_type: struct.calcsize(fmt)
    for value_type, fmt in _simple_value_packing.items()
}
|
||||
|
||||
|
||||
def get_single(value_type, file):
    """Read one non-array metadata value from ``file`` at its current position.

    Args:
        value_type: ``GGUFValueType`` tag describing how the value is encoded.
        file: Binary file object positioned at the first byte of the value.

    Returns:
        A ``str`` for ``STRING`` values, otherwise the unpacked scalar.

    Raises:
        ValueError: If ``value_type`` has no fixed-size encoding in the
            lookup tables (for example a nested ``ARRAY``).
        struct.error: If fewer bytes than required could be read.
    """
    if value_type == GGUFValueType.STRING:
        # A string is a little-endian uint64 byte count followed by the bytes.
        value_length = struct.unpack("<Q", file.read(8))[0]
        # Be lenient when decoding: one string that is not valid UTF-8 should
        # not abort the whole metadata parse with UnicodeDecodeError.
        return file.read(value_length).decode('utf-8', errors='replace')

    type_str = _simple_value_packing.get(value_type)
    bytes_length = value_type_info.get(value_type)
    if type_str is None or bytes_length is None:
        # Without this guard struct.unpack(None, ...) fails with an opaque
        # TypeError that hides which type tag was the problem.
        raise ValueError(f"Unsupported GGUF value type: {value_type!r}")

    return struct.unpack(type_str, file.read(bytes_length))[0]
|
||||
|
||||
|
||||
def load_metadata(fname):
    """Read the scalar key/value metadata stored in the header of a GGUF file.

    Array-valued entries are parsed only to advance past them; they are not
    included in the returned dictionary.

    Args:
        fname: Path of the file to read.

    Returns:
        A dict mapping each metadata key to its decoded value.

    Raises:
        ValueError: If the file does not begin with the ``GGUF`` magic bytes,
            or a value carries a type tag that ``GGUFValueType`` does not know.
        struct.error: If the file ends in the middle of a field.
    """
    metadata = {}
    with open(fname, 'rb') as file:
        # Refuse anything that is not GGUF up front; otherwise the counts and
        # lengths below would be read from unrelated bytes.
        if file.read(4) != b'GGUF':
            raise ValueError(f"{fname} is not a GGUF file (bad magic number)")

        # Remaining header: uint32 version, uint64 tensor-info count and
        # uint64 key/value count. Only the last one is needed here.
        _version, _ti_data_count, kv_data_count = struct.unpack("<IQQ", file.read(20))

        for _ in range(kv_data_count):
            key_length = struct.unpack("<Q", file.read(8))[0]
            key = file.read(key_length)

            value_type = GGUFValueType(struct.unpack("<I", file.read(4))[0])
            if value_type == GGUFValueType.ARRAY:
                # An array is an element type tag, an element count, then the
                # elements. Read and discard them to reach the next key.
                ltype = GGUFValueType(struct.unpack("<I", file.read(4))[0])
                length = struct.unpack("<Q", file.read(8))[0]
                for _element in range(length):
                    get_single(ltype, file)
            else:
                metadata[key.decode()] = get_single(value_type, file)

    return metadata
|
@ -18,9 +18,9 @@ from transformers import (
|
||||
)
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import llama_attn_hijack, RoPE, sampler_hijack
|
||||
from modules import RoPE, llama_attn_hijack, sampler_hijack
|
||||
from modules.logging_colors import logger
|
||||
from modules.models_settings import infer_loader
|
||||
from modules.models_settings import get_model_metadata
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
@ -62,15 +62,11 @@ def load_model(model_name, loader=None):
|
||||
'ctransformers': ctransformers_loader,
|
||||
}
|
||||
|
||||
p = Path(model_name)
|
||||
if p.exists():
|
||||
model_name = p.parts[-1]
|
||||
|
||||
if loader is None:
|
||||
if shared.args.loader is not None:
|
||||
loader = shared.args.loader
|
||||
else:
|
||||
loader = infer_loader(model_name)
|
||||
loader = get_model_metadata(model_name)['loader']
|
||||
if loader is None:
|
||||
logger.error('The path to the model does not exist. Exiting.')
|
||||
return None, None
|
||||
|
@ -3,23 +3,57 @@ from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from modules import loaders, shared, ui
|
||||
from modules import loaders, metadata_gguf, shared, ui
|
||||
|
||||
|
||||
def get_model_settings_from_yamls(model):
|
||||
settings = shared.model_config
|
||||
def get_fallback_settings():
    """Return the default per-model settings.

    Three entries are copied from the current ``shared.settings``; the rest
    are fixed defaults.
    """
    return {
        'wbits': 'None',
        'model_type': 'None',
        'groupsize': 'None',
        'pre_layer': 0,
        'skip_special_tokens': shared.settings['skip_special_tokens'],
        'custom_stopping_strings': shared.settings['custom_stopping_strings'],
        'truncation_length': shared.settings['truncation_length'],
        'n_ctx': 2048,
        'rope_freq_base': 0,
    }
|
||||
|
||||
|
||||
def get_model_metadata(model):
    """Collect the settings that apply to ``model``.

    Settings are gathered in three steps:

    1. Every entry of ``shared.model_config`` whose key, used as a regex,
       matches the model name (case-insensitively) is merged in iteration
       order, so later matches override earlier ones.
    2. If no ``loader`` was configured, one is inferred with
       ``infer_loader``; a positive integer ``wbits`` forces ``AutoGPTQ``.
    3. For llama.cpp-family loaders, ``n_ctx`` is taken from the GGUF
       header when it provides ``llama.context_length``.

    Args:
        model: Model name, relative to ``shared.args.model_dir``.

    Returns:
        A dict of settings that always contains a ``loader`` key.
    """
    model_settings = {}

    # Get settings from models/config.yaml and models/config-user.yaml
    settings = shared.model_config
    for pat in settings:
        if re.match(pat.lower(), model.lower()):
            model_settings.update(settings[pat])

    if 'loader' not in model_settings:
        loader = infer_loader(model, model_settings)
        if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
            loader = 'AutoGPTQ'

        model_settings['loader'] = loader

    # Read GGUF metadata
    if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
        path = Path(f'{shared.args.model_dir}/{model}')
        if path.is_file():
            model_file = path
        else:
            # The directory may contain no .gguf file at all; indexing the
            # glob result would raise IndexError in that case.
            model_file = next(path.glob('*.gguf'), None)

        if model_file is not None:
            metadata = metadata_gguf.load_metadata(model_file)
            if 'llama.context_length' in metadata:
                model_settings['n_ctx'] = metadata['llama.context_length']

    return model_settings
|
||||
|
||||
|
||||
def infer_loader(model_name):
|
||||
def infer_loader(model_name, model_settings):
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
model_settings = get_model_settings_from_yamls(model_name)
|
||||
if not path_to_model.exists():
|
||||
loader = None
|
||||
elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0):
|
||||
@ -85,11 +119,9 @@ def update_model_parameters(state, initial=False):
|
||||
|
||||
# UI: update the state variable with the model settings
|
||||
def apply_model_settings_to_state(model, state):
|
||||
model_settings = get_model_settings_from_yamls(model)
|
||||
if 'loader' not in model_settings:
|
||||
loader = infer_loader(model)
|
||||
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
|
||||
loader = 'AutoGPTQ'
|
||||
model_settings = get_model_metadata(model)
|
||||
if 'loader' in model_settings:
|
||||
loader = model_settings.pop('loader')
|
||||
|
||||
# If the user is using an alternative loader for the same model type, let them keep using it
|
||||
if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama', 'ExLlama_HF']) and not (loader == 'llama.cpp' and state['loader'] in ['llamacpp_HF', 'ctransformers']):
|
||||
|
@ -15,7 +15,7 @@ from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import (
|
||||
apply_model_settings_to_state,
|
||||
get_model_settings_from_yamls,
|
||||
get_model_metadata,
|
||||
save_model_settings,
|
||||
update_model_parameters
|
||||
)
|
||||
@ -196,7 +196,7 @@ def load_model_wrapper(selected_model, loader, autoload=False):
|
||||
if shared.model is not None:
|
||||
output = f"Successfully loaded `{selected_model}`."
|
||||
|
||||
settings = get_model_settings_from_yamls(selected_model)
|
||||
settings = get_model_metadata(selected_model)
|
||||
if 'instruction_template' in settings:
|
||||
output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template'])
|
||||
|
||||
|
30
server.py
30
server.py
@ -1,8 +1,8 @@
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from modules.logging_colors import logger
|
||||
from modules.block_requests import OpenMonkeyPatch, RequestBlocker
|
||||
from modules.logging_colors import logger
|
||||
|
||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
||||
@ -12,6 +12,7 @@ with RequestBlocker():
|
||||
import gradio as gr
|
||||
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
|
||||
|
||||
import json
|
||||
@ -37,13 +38,14 @@ from modules import (
|
||||
ui_notebook,
|
||||
ui_parameters,
|
||||
ui_session,
|
||||
utils,
|
||||
utils
|
||||
)
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model
|
||||
from modules.models_settings import (
|
||||
get_model_settings_from_yamls,
|
||||
get_fallback_settings,
|
||||
get_model_metadata,
|
||||
update_model_parameters
|
||||
)
|
||||
from modules.utils import gradio
|
||||
@ -169,17 +171,7 @@ if __name__ == "__main__":
|
||||
shared.settings.update(new_settings)
|
||||
|
||||
# Fallback settings for models
|
||||
shared.model_config['.*'] = {
|
||||
'wbits': 'None',
|
||||
'model_type': 'None',
|
||||
'groupsize': 'None',
|
||||
'pre_layer': 0,
|
||||
'skip_special_tokens': shared.settings['skip_special_tokens'],
|
||||
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
|
||||
'truncation_length': shared.settings['truncation_length'],
|
||||
'rope_freq_base': 0,
|
||||
}
|
||||
|
||||
shared.model_config['.*'] = get_fallback_settings()
|
||||
shared.model_config.move_to_end('.*', last=False) # Move to the beginning
|
||||
|
||||
# Activate the extensions listed on settings.yaml
|
||||
@ -213,12 +205,18 @@ if __name__ == "__main__":
|
||||
|
||||
# If any model has been selected, load it
|
||||
if shared.model_name != 'None':
|
||||
model_settings = get_model_settings_from_yamls(shared.model_name)
|
||||
p = Path(shared.model_name)
|
||||
if p.exists():
|
||||
model_name = p.parts[-1]
|
||||
else:
|
||||
model_name = shared.model_name
|
||||
|
||||
model_settings = get_model_metadata(model_name)
|
||||
shared.settings.update(model_settings) # hijacking the interface defaults
|
||||
update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments
|
||||
|
||||
# Load the model
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
shared.model, shared.tokenizer = load_model(model_name)
|
||||
if shared.args.lora:
|
||||
add_lora_to_model(shared.args.lora)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user