Merge pull request #3695 from oobabooga/gguf2

GGUF
commit e6eda5c2da
oobabooga 2023-08-27 02:33:26 -03:00 committed by GitHub
9 changed files with 121 additions and 41 deletions

View File

@@ -156,7 +156,7 @@ text-generation-webui
 In the "Model" tab of the UI, those models can be automatically downloaded from Hugging Face. You can also download them via the command-line with `python download-model.py organization/model`.
-* GGML models are a single file and should be placed directly into `models`. Example:
+* GGML/GGUF models are a single file and should be placed directly into `models`. Example:
 ```
 text-generation-webui
@@ -258,7 +258,7 @@ Optionally, you can use the following command-line flags:
 | `--quant_type QUANT_TYPE` | quant_type for 4-bit. Valid options: nf4, fp4. |
 | `--use_double_quant` | use_double_quant for 4-bit. |

-#### GGML (for llama.cpp and ctransformers)
+#### GGML/GGUF (for llama.cpp and ctransformers)

 | Flag | Description |
 |-------------|-------------|
@@ -269,16 +269,16 @@ Optionally, you can use the following command-line flags:
 #### llama.cpp

 | Flag | Description |
-|-------------|-------------|
+|---------------|---------------|
 | `--no-mmap` | Prevent mmap from being used. |
 | `--mlock` | Force the system to keep the model in RAM. |
 | `--mul_mat_q` | Activate new mulmat kernels. |
 | `--cache-capacity CACHE_CAPACITY` | Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed. |
 | `--tensor_split TENSOR_SPLIT` | Split the model across multiple GPUs, comma-separated list of proportions, e.g. 18,17 |
 | `--llama_cpp_seed SEED` | Seed for llama-cpp models. Default 0 (random). |
-| `--n_gqa N_GQA` | grouped-query attention. Must be 8 for llama-2 70b. |
-| `--rms_norm_eps RMS_NORM_EPS` | 5e-6 is a good value for llama-2 models. |
+| `--n_gqa N_GQA` | GGML only (not used by GGUF): Grouped-Query Attention. Must be 8 for llama-2 70b. |
+| `--rms_norm_eps RMS_NORM_EPS` | GGML only (not used by GGUF): 5e-6 is a good value for llama-2 models. |
 | `--cpu` | Use the CPU version of llama-cpp-python instead of the GPU-accelerated version. |
 | `--cfg-cache` | llamacpp_HF: Create an additional cache for CFG negative prompts. |

View File

@@ -57,7 +57,8 @@ class ModelDownloader:
         classifications = []
         has_pytorch = False
         has_pt = False
-        # has_ggml = False
+        has_gguf = False
+        has_ggml = False
         has_safetensors = False
         is_lora = False
         while True:
@@ -78,10 +79,11 @@ class ModelDownloader:
                 is_pytorch = re.match(r"(pytorch|adapter|gptq)_model.*\.bin", fname)
                 is_safetensors = re.match(r".*\.safetensors", fname)
                 is_pt = re.match(r".*\.pt", fname)
+                is_gguf = re.match(r'.*\.gguf', fname)
                 is_ggml = re.match(r".*ggml.*\.bin", fname)
                 is_tokenizer = re.match(r"(tokenizer|ice|spiece).*\.model", fname)
                 is_text = re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer
-                if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
+                if any((is_pytorch, is_safetensors, is_pt, is_gguf, is_ggml, is_tokenizer, is_text)):
                     if 'lfs' in dict[i]:
                         sha256.append([fname, dict[i]['lfs']['oid']])
@@ -101,8 +103,11 @@ class ModelDownloader:
                         elif is_pt:
                             has_pt = True
                             classifications.append('pt')
+                        elif is_gguf:
+                            has_gguf = True
+                            classifications.append('gguf')
                         elif is_ggml:
-                            # has_ggml = True
+                            has_ggml = True
                             classifications.append('ggml')

             cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
@@ -115,6 +120,12 @@ class ModelDownloader:
                 if classifications[i] in ['pytorch', 'pt']:
                     links.pop(i)

+        # If both GGML and GGUF are available, download GGUF only
+        if has_ggml and has_gguf:
+            for i in range(len(classifications) - 1, -1, -1):
+                if classifications[i] == 'ggml':
+                    links.pop(i)
+
         return links, sha256, is_lora

     def get_output_folder(self, model, branch, is_lora, base_folder=None):
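The effect of the new `has_gguf`/`has_ggml` bookkeeping is easiest to see in isolation. Below is a minimal, self-contained sketch of the same classify-then-filter idea (file names are hypothetical, not taken from any repository): when a repo ships both GGML and GGUF weights, only the GGUF links survive.

```python
import re

def filter_llama_cpp_links(filenames):
    # Sketch of the downloader logic above: classify each file, then drop
    # GGML links whenever a GGUF alternative exists.
    links, classifications = [], []
    has_gguf = has_ggml = False
    for fname in filenames:
        if re.match(r'.*\.gguf', fname):
            has_gguf = True
            links.append(fname)
            classifications.append('gguf')
        elif re.match(r'.*ggml.*\.bin', fname):
            has_ggml = True
            links.append(fname)
            classifications.append('ggml')

    # If both GGML and GGUF are available, download GGUF only
    if has_ggml and has_gguf:
        for i in range(len(classifications) - 1, -1, -1):
            if classifications[i] == 'ggml':
                links.pop(i)

    return links

# Hypothetical repository listing:
print(filter_llama_cpp_links(['llama-2-7b.Q4_K_M.gguf', 'llama-2-7b.ggmlv3.q4_K_M.bin']))
# ['llama-2-7b.Q4_K_M.gguf']
```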

View File

@@ -9,27 +9,43 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from modules import RoPE, shared
 from modules.logging_colors import logger
+from modules.utils import is_gguf

 import llama_cpp

+try:
+    import llama_cpp_ggml
+except:
+    llama_cpp_ggml = llama_cpp
+
 if torch.cuda.is_available() and not torch.version.hip:
     try:
         import llama_cpp_cuda
     except:
         llama_cpp_cuda = None
+
+    try:
+        import llama_cpp_ggml_cuda
+    except:
+        llama_cpp_ggml_cuda = llama_cpp_cuda
 else:
     llama_cpp_cuda = None
+    llama_cpp_ggml_cuda = None


-def llama_cpp_lib():
-    if shared.args.cpu or llama_cpp_cuda is None:
-        return llama_cpp
-    else:
-        return llama_cpp_cuda
+def llama_cpp_lib(model_file: Union[str, Path] = None):
+    if model_file is not None:
+        gguf_model = is_gguf(model_file)
+    else:
+        gguf_model = True
+
+    if shared.args.cpu or llama_cpp_cuda is None:
+        return llama_cpp if gguf_model else llama_cpp_ggml
+    else:
+        return llama_cpp_cuda if gguf_model else llama_cpp_ggml_cuda


 class LlamacppHF(PreTrainedModel):
-    def __init__(self, model):
+    def __init__(self, model, path):
         super().__init__(PretrainedConfig())
         self.model = model
         self.generation_config = GenerationConfig()
@@ -48,7 +64,7 @@ class LlamacppHF(PreTrainedModel):
             'n_tokens': self.model.n_tokens,
             'input_ids': self.model.input_ids.copy(),
             'scores': self.model.scores.copy(),
-            'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.params)
+            'ctx': llama_cpp_lib(path).llama_new_context_with_model(model.model, model.params)
         }

     def _validate_model_class(self):
@@ -165,7 +181,7 @@ class LlamacppHF(PreTrainedModel):
         if path.is_file():
             model_file = path
         else:
-            model_file = list(path.glob('*ggml*.bin'))[0]
+            model_file = (list(path.glob('*.gguf*')) + list(path.glob('*ggml*.bin')))[0]

         logger.info(f"llama.cpp weights detected: {model_file}\n")
@@ -188,12 +204,17 @@ class LlamacppHF(PreTrainedModel):
             'rope_freq_base': RoPE.get_rope_freq_base(shared.args.alpha_value, shared.args.rope_freq_base),
             'tensor_split': tensor_split_list,
             'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
-            'n_gqa': shared.args.n_gqa or None,
-            'rms_norm_eps': shared.args.rms_norm_eps or None,
             'logits_all': True,
         }

+        if not is_gguf(model_file):
+            ggml_params = {
+                'n_gqa': shared.args.n_gqa or None,
+                'rms_norm_eps': shared.args.rms_norm_eps or None,
+            }
+            params = params | ggml_params
+
-        Llama = llama_cpp_lib().Llama
+        Llama = llama_cpp_lib(model_file).Llama
         model = Llama(**params)

-        return LlamacppHF(model)
+        return LlamacppHF(model, model_file)
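With two extra wheels in play (`llama_cpp_ggml` and `llama_cpp_ggml_cuda`), `llama_cpp_lib(model_file)` now picks one of four backends based on the file format and whether CPU mode is in effect. A rough truth-table sketch of that decision, with the modules reduced to plain strings so it runs without any wheel installed (assumption: a missing CUDA build is treated the same as `--cpu`):

```python
def pick_backend(is_gguf_file: bool, cpu_or_no_cuda: bool) -> str:
    # Mirrors the branching in llama_cpp_lib(); the real function returns the
    # imported module rather than its name.
    if cpu_or_no_cuda:
        return 'llama_cpp' if is_gguf_file else 'llama_cpp_ggml'
    return 'llama_cpp_cuda' if is_gguf_file else 'llama_cpp_ggml_cuda'

for gguf in (True, False):
    for cpu in (True, False):
        print(f'gguf={gguf!s:5} cpu={cpu!s:5} -> {pick_backend(gguf, cpu)}')
```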

View File

@@ -1,5 +1,7 @@
 import re
 from functools import partial
+from pathlib import Path
+from typing import Union

 import torch
@@ -7,23 +9,39 @@ from modules import RoPE, shared
 from modules.callbacks import Iteratorize
 from modules.logging_colors import logger
 from modules.text_generation import get_max_prompt_length
+from modules.utils import is_gguf

 import llama_cpp

+try:
+    import llama_cpp_ggml
+except:
+    llama_cpp_ggml = llama_cpp
+
 if torch.cuda.is_available() and not torch.version.hip:
     try:
         import llama_cpp_cuda
     except:
         llama_cpp_cuda = None
+
+    try:
+        import llama_cpp_ggml_cuda
+    except:
+        llama_cpp_ggml_cuda = llama_cpp_cuda
 else:
     llama_cpp_cuda = None
+    llama_cpp_ggml_cuda = None


-def llama_cpp_lib():
-    if shared.args.cpu or llama_cpp_cuda is None:
-        return llama_cpp
-    else:
-        return llama_cpp_cuda
+def llama_cpp_lib(model_file: Union[str, Path] = None):
+    if model_file is not None:
+        gguf_model = is_gguf(model_file)
+    else:
+        gguf_model = True
+
+    if shared.args.cpu or llama_cpp_cuda is None:
+        return llama_cpp if gguf_model else llama_cpp_ggml
+    else:
+        return llama_cpp_cuda if gguf_model else llama_cpp_ggml_cuda


 def ban_eos_logits_processor(eos_token, input_ids, logits):
@@ -41,8 +59,8 @@ class LlamaCppModel:
     @classmethod
     def from_pretrained(self, path):
-        Llama = llama_cpp_lib().Llama
-        LlamaCache = llama_cpp_lib().LlamaCache
+        Llama = llama_cpp_lib(path).Llama
+        LlamaCache = llama_cpp_lib(path).LlamaCache

         result = self()
         cache_capacity = 0
@@ -75,10 +93,15 @@ class LlamaCppModel:
             'rope_freq_base': RoPE.get_rope_freq_base(shared.args.alpha_value, shared.args.rope_freq_base),
             'tensor_split': tensor_split_list,
             'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
-            'n_gqa': shared.args.n_gqa or None,
-            'rms_norm_eps': shared.args.rms_norm_eps or None,
         }

+        if not is_gguf(path):
+            ggml_params = {
+                'n_gqa': shared.args.n_gqa or None,
+                'rms_norm_eps': shared.args.rms_norm_eps or None,
+            }
+            params = params | ggml_params
+
         result.model = Llama(**params)
         if cache_capacity > 0:
             result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
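Both loaders now build the GGML-only arguments conditionally and merge them with the dictionary union operator. A standalone sketch of that merge (values are illustrative; in the real code they come from `shared.args` and `is_gguf()`), requiring Python 3.9+ for `dict | dict`:

```python
params = {
    'model_path': 'models/llama-2-70b.ggmlv3.q4_K_M.bin',  # hypothetical file
    'n_ctx': 4096,
}

file_is_gguf = False  # the real code calls modules.utils.is_gguf(path)
if not file_is_gguf:
    # GGML files do not carry these values, so they must be passed explicitly;
    # GGUF stores them in the file's metadata, hence "not used by GGUF".
    ggml_params = {
        'n_gqa': 8,           # required for llama-2 70b
        'rms_norm_eps': 5e-6,
    }
    params = params | ggml_params

print(sorted(params))  # ['model_path', 'n_ctx', 'n_gqa', 'rms_norm_eps']
```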

View File

@@ -241,7 +241,7 @@ def llamacpp_loader(model_name):
     if path.is_file():
         model_file = path
     else:
-        model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
+        model_file = (list(Path(f'{shared.args.model_dir}/{model_name}').glob('*.gguf*')) + list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin')))[0]

     logger.info(f"llama.cpp weights detected: {model_file}")
     model, tokenizer = LlamaCppModel.from_pretrained(model_file)
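The weight lookup now globs for `*.gguf*` before `*ggml*.bin`, so a folder containing both formats resolves to the GGUF file. A small sketch of that preference (directory and file names are hypothetical):

```python
from pathlib import Path

def find_llama_cpp_weights(model_dir: str) -> Path:
    # Same ordering as above: GGUF candidates first, legacy GGML .bin second.
    path = Path(model_dir)
    candidates = list(path.glob('*.gguf*')) + list(path.glob('*ggml*.bin'))
    if not candidates:
        raise FileNotFoundError(f'No llama.cpp weights found in {model_dir}')
    return candidates[0]

# Example: find_llama_cpp_weights('models/llama-2-13b-chat')
```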

View File

@@ -24,9 +24,9 @@ def infer_loader(model_name):
         loader = None
     elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0):
         loader = 'AutoGPTQ'
-    elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
+    elif len(list(path_to_model.glob('*.gguf*')) + list(path_to_model.glob('*ggml*.bin'))) > 0:
         loader = 'llama.cpp'
-    elif re.match(r'.*ggml.*\.bin', model_name.lower()):
+    elif re.match(r'.*\.gguf|.*ggml.*\.bin', model_name.lower()):
         loader = 'llama.cpp'
     elif re.match(r'.*rwkv.*\.pth', model_name.lower()):
         loader = 'RWKV'
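For single-file models the loader is inferred from the name alone, and the updated regex accepts either extension style. A quick check of the pattern (model names are made up):

```python
import re

def guess_loader(model_name: str):
    # Same pattern as infer_loader(): .gguf or the legacy ggml*.bin naming.
    if re.match(r'.*\.gguf|.*ggml.*\.bin', model_name.lower()):
        return 'llama.cpp'
    return None

print(guess_loader('Llama-2-13B-chat.Q5_K_M.gguf'))      # llama.cpp
print(guess_loader('llama-2-13b-chat.ggmlv3.q4_0.bin'))  # llama.cpp
print(guess_loader('Llama-2-13B-chat-GPTQ'))             # None
```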

View File

@@ -80,8 +80,8 @@ def create_ui():
         shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=16384, step=256, label="n_ctx", value=shared.args.n_ctx)
         shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=32, value=shared.args.threads)
         shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, value=shared.args.n_batch)
-        shared.gradio['n_gqa'] = gr.Slider(minimum=0, maximum=16, step=1, label="n_gqa", value=shared.args.n_gqa, info='grouped-query attention. Must be 8 for llama-2 70b.')
-        shared.gradio['rms_norm_eps'] = gr.Slider(minimum=0, maximum=1e-5, step=1e-6, label="rms_norm_eps", value=shared.args.rms_norm_eps, info='5e-6 is a good value for llama-2 models.')
+        shared.gradio['n_gqa'] = gr.Slider(minimum=0, maximum=16, step=1, label="n_gqa", value=shared.args.n_gqa, info='GGML only (not used by GGUF): Grouped-Query Attention. Must be 8 for llama-2 70b.')
+        shared.gradio['rms_norm_eps'] = gr.Slider(minimum=0, maximum=1e-5, step=1e-6, label="rms_norm_eps", value=shared.args.rms_norm_eps, info='GGML only (not used by GGUF): 5e-6 is a good value for llama-2 models.')

         shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None")
         shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None")

View File

@@ -2,6 +2,7 @@ import os
 import re
 from datetime import datetime
 from pathlib import Path
+from typing import Union

 from modules import shared
 from modules.logging_colors import logger
@@ -124,3 +125,15 @@ def get_datasets(path: str, ext: str):

 def get_available_chat_styles():
     return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys)
+
+
+def is_gguf(path: Union[str, Path]) -> bool:
+    '''
+    Determines if a llama.cpp model is in GGUF format
+    Copied from ctransformers utils.py
+    '''
+    path = str(Path(path).resolve())
+    with open(path, "rb") as f:
+        magic = f.read(4)
+
+    return magic == "GGUF".encode()
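`is_gguf()` only looks at the first four bytes of the file, which GGUF reserves for the ASCII magic `GGUF`. A throwaway demonstration of the same check with a fake header (no real model required):

```python
import tempfile

def looks_like_gguf(path: str) -> bool:
    # Same check as is_gguf(): read the 4-byte magic at the start of the file.
    with open(path, 'rb') as f:
        return f.read(4) == b'GGUF'

# Write a fake header; real GGUF files follow the magic with version,
# tensor counts and key/value metadata.
with tempfile.NamedTemporaryFile(suffix='.gguf', delete=False) as tmp:
    tmp.write(b'GGUF' + b'\x00' * 16)

print(looks_like_gguf(tmp.name))  # True
```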

View File

@@ -22,19 +22,31 @@ tensorboard
 tqdm
 wandb

+# bitsandbytes
 bitsandbytes==0.41.1; platform_system != "Windows"
 https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"

+# AutoGPTQ
 https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.4.2/auto_gptq-0.4.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
 https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.4.2/auto_gptq-0.4.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

+# ExLlama
 https://github.com/jllllll/exllama/releases/download/0.0.10/exllama-0.0.10+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
 https://github.com/jllllll/exllama/releases/download/0.0.10/exllama-0.0.10+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

 # llama-cpp-python without GPU support
-llama-cpp-python==0.1.78; platform_system != "Windows"
-https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.78/llama_cpp_python-0.1.78-cp310-cp310-win_amd64.whl; platform_system == "Windows"
+llama-cpp-python==0.1.79; platform_system != "Windows"
+https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.79/llama_cpp_python-0.1.79-cp310-cp310-win_amd64.whl; platform_system == "Windows"

 # llama-cpp-python with CUDA support
-https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.1.78+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
-https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.1.78+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
+https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.1.79+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
+https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.1.79+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

+# llama-cpp-python with GGML support
+https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python_ggml-0.1.78+cpuavx2-cp310-cp310-win_amd64.whl; platform_system == "Windows"
+https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python_ggml-0.1.78+cpuavx2-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
+https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_ggml_cuda-0.1.78+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
+https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_ggml_cuda-0.1.78+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
+
 # GPTQ-for-LLaMa
 https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.0/gptq_for_llama-0.1.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"