Add ctransformers support (#3313)

---------

Co-authored-by: cal066 <cal066@users.noreply.github.com>
Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
Co-authored-by: randoentity <137087500+randoentity@users.noreply.github.com>
This commit is contained in:
cal066 2023-08-11 17:41:33 +00:00 committed by GitHub
parent 8dbaa20ca8
commit 7a4fcee069
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 188 additions and 43 deletions

View File

@ -11,7 +11,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
## Features ## Features
* 3 interface modes: default, notebook, and chat * 3 interface modes: default, notebook, and chat
* Multiple model backends: transformers, llama.cpp, ExLlama, AutoGPTQ, GPTQ-for-LLaMa * Multiple model backends: transformers, llama.cpp, ExLlama, AutoGPTQ, GPTQ-for-LLaMa, ctransformers
* Dropdown menu for quickly switching between different models * Dropdown menu for quickly switching between different models
* LoRA: load and unload LoRAs on the fly, train a new LoRA * LoRA: load and unload LoRAs on the fly, train a new LoRA
* Precise instruction templates for chat mode, including Llama 2, Alpaca, Vicuna, WizardLM, StableLM, and many others * Precise instruction templates for chat mode, including Llama 2, Alpaca, Vicuna, WizardLM, StableLM, and many others

View File

@ -0,0 +1,76 @@
from ctransformers import AutoConfig, AutoModelForCausalLM
from modules import shared
from modules.callbacks import Iteratorize
from modules.logging_colors import logger
class CtransformersModel:
def __init__(self):
pass
@classmethod
def from_pretrained(self, path):
result = self()
# ctransformers uses -1 for random seed
config = AutoConfig.from_pretrained(
str(path),
threads=shared.args.threads,
gpu_layers=shared.args.n_gpu_layers,
batch_size=shared.args.n_batch,
stream=True,
seed=(-1 if shared.args.llama_cpp_seed == 0 else shared.args.llama_cpp_seed)
)
self.model = AutoModelForCausalLM.from_pretrained(
str(result.model_dir(path) if result.model_type_is_auto() else path),
model_type=(None if result.model_type_is_auto() else shared.args.model_type),
config=config
)
logger.info(f'Using ctransformers model_type: {self.model.model_type} for {self.model.model_path}')
return result, result
def model_type_is_auto(self):
return shared.args.model_type == "Auto" or shared.args.model_type == "None"
def model_dir(self, path):
if path.is_file():
return path.parent
return path
def encode(self, string, **kwargs):
return self.model.tokenize(string)
def decode(self, ids):
return self.model.detokenize(ids)
def generate(self, prompt, state, callback=None):
prompt = prompt if type(prompt) is str else prompt.decode()
generator = self.model._stream(
prompt=prompt,
max_new_tokens=state['max_new_tokens'],
temperature=state['temperature'],
top_p=state['top_p'],
top_k=state['top_k'],
repetition_penalty=state['repetition_penalty'],
threads=shared.args.threads
)
output = ""
for token in generator:
if callback:
callback(token)
output += token
return output
def generate_with_streaming(self, *args, **kwargs):
with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
reply = ''
for token in generator:
reply += token
yield reply

View File

@ -1,10 +1,43 @@
import functools import functools
from collections import OrderedDict
import gradio as gr import gradio as gr
from modules import shared from modules import shared
loaders_and_params = { loaders_and_params = OrderedDict({
'Transformers': [
'cpu_memory',
'gpu_memory',
'trust_remote_code',
'load_in_8bit',
'bf16',
'cpu',
'disk',
'auto_devices',
'load_in_4bit',
'use_double_quant',
'quant_type',
'compute_dtype',
'trust_remote_code',
'alpha_value',
'compress_pos_emb',
'transformers_info'
],
'ExLlama_HF': [
'gpu_split',
'max_seq_len',
'alpha_value',
'compress_pos_emb',
'exllama_HF_info',
],
'ExLlama': [
'gpu_split',
'max_seq_len',
'alpha_value',
'compress_pos_emb',
'exllama_info',
],
'AutoGPTQ': [ 'AutoGPTQ': [
'triton', 'triton',
'no_inject_fused_attention', 'no_inject_fused_attention',
@ -59,39 +92,15 @@ loaders_and_params = {
'cpu', 'cpu',
'llamacpp_HF_info', 'llamacpp_HF_info',
], ],
'Transformers': [ 'ctransformers': [
'cpu_memory', 'n_ctx',
'gpu_memory', 'n_gpu_layers',
'trust_remote_code', 'n_batch',
'load_in_8bit', 'threads',
'bf16', 'model_type',
'cpu', 'llama_cpp_seed',
'disk',
'auto_devices',
'load_in_4bit',
'use_double_quant',
'quant_type',
'compute_dtype',
'trust_remote_code',
'alpha_value',
'compress_pos_emb',
'transformers_info'
],
'ExLlama': [
'gpu_split',
'max_seq_len',
'alpha_value',
'compress_pos_emb',
'exllama_info',
],
'ExLlama_HF': [
'gpu_split',
'max_seq_len',
'alpha_value',
'compress_pos_emb',
'exllama_HF_info',
] ]
} })
loaders_samplers = { loaders_samplers = {
'Transformers': { 'Transformers': {
@ -256,6 +265,33 @@ loaders_samplers = {
'skip_special_tokens', 'skip_special_tokens',
'auto_max_new_tokens', 'auto_max_new_tokens',
}, },
'ctransformers': {
'temperature',
'top_p',
'top_k',
'repetition_penalty',
}
}
loaders_model_types = {
'GPTQ-for-LLaMa': [
"None",
"llama",
"opt",
"gptj"
],
'ctransformers': [
"None",
"gpt2",
"gptj",
"gptneox",
"llama",
"mpt",
"dollyv2"
"replit",
"starcoder",
"falcon"
],
} }
@ -277,6 +313,13 @@ def blacklist_samplers(loader):
return [gr.update(visible=True) if sampler in loaders_samplers[loader] else gr.update(visible=False) for sampler in all_samplers] return [gr.update(visible=True) if sampler in loaders_samplers[loader] else gr.update(visible=False) for sampler in all_samplers]
def get_model_types(loader):
if loader in loaders_model_types:
return loaders_model_types[loader]
return ["None"]
def get_gpu_memory_keys(): def get_gpu_memory_keys():
return [k for k in shared.gradio if k.startswith('gpu_memory')] return [k for k in shared.gradio if k.startswith('gpu_memory')]

View File

@ -58,7 +58,8 @@ def load_model(model_name, loader=None):
'llamacpp_HF': llamacpp_HF_loader, 'llamacpp_HF': llamacpp_HF_loader,
'RWKV': RWKV_loader, 'RWKV': RWKV_loader,
'ExLlama': ExLlama_loader, 'ExLlama': ExLlama_loader,
'ExLlama_HF': ExLlama_HF_loader 'ExLlama_HF': ExLlama_HF_loader,
'ctransformers': ctransformers_loader,
} }
p = Path(model_name) p = Path(model_name)
@ -242,7 +243,7 @@ def llamacpp_loader(model_name):
else: else:
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0] model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
logger.info(f"llama.cpp weights detected: {model_file}\n") logger.info(f"llama.cpp weights detected: {model_file}")
model, tokenizer = LlamaCppModel.from_pretrained(model_file) model, tokenizer = LlamaCppModel.from_pretrained(model_file)
return model, tokenizer return model, tokenizer
@ -268,6 +269,24 @@ def llamacpp_HF_loader(model_name):
return model, tokenizer return model, tokenizer
def ctransformers_loader(model_name):
from modules.ctransformers_model import CtransformersModel
path = Path(f'{shared.args.model_dir}/{model_name}')
ctrans = CtransformersModel()
if ctrans.model_type_is_auto():
model_file = path
else:
if path.is_file():
model_file = path
else:
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*.bin'))[0]
logger.info(f'ctransformers weights detected: {model_file}')
model, tokenizer = ctrans.from_pretrained(model_file)
return model, tokenizer
def GPTQ_loader(model_name): def GPTQ_loader(model_name):
# Monkey patch # Monkey patch

View File

@ -215,6 +215,8 @@ def fix_loader_name(name):
return 'ExLlama' return 'ExLlama'
elif name in ['exllama-hf', 'exllama_hf', 'exllama hf', 'ex-llama-hf', 'ex_llama_hf']: elif name in ['exllama-hf', 'exllama_hf', 'exllama hf', 'ex-llama-hf', 'ex_llama_hf']:
return 'ExLlama_HF' return 'ExLlama_HF'
elif name in ['ctransformers', 'ctranforemrs', 'ctransformer']:
return 'ctransformers'
def add_extension(name): def add_extension(name):

View File

@ -41,7 +41,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
yield '' yield ''
return return
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']: if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'CtransformersModel']:
generate_func = generate_reply_custom generate_func = generate_reply_custom
else: else:
generate_func = generate_reply_HF generate_func = generate_reply_HF
@ -90,7 +90,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None): def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']: if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel']:
input_ids = shared.tokenizer.encode(str(prompt)) input_ids = shared.tokenizer.encode(str(prompt))
input_ids = np.array(input_ids).reshape(1, len(input_ids)) input_ids = np.array(input_ids).reshape(1, len(input_ids))
else: else:
@ -104,7 +104,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if truncation_length is not None: if truncation_length is not None:
input_ids = input_ids[:, -truncation_length:] input_ids = input_ids[:, -truncation_length:]
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu: if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'CtransformersModel'] or shared.args.cpu:
return input_ids return input_ids
elif shared.args.deepspeed: elif shared.args.deepspeed:
return input_ids.to(device=local_rank) return input_ids.to(device=local_rank)

View File

@ -63,7 +63,7 @@ def create_ui():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "ExLlama_HF", "ExLlama", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp", "llamacpp_HF"], value=None) shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=loaders.loaders_and_params.keys(), value=None)
with gr.Box(): with gr.Box():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -84,7 +84,7 @@ def create_ui():
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None") shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None")
shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None") shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None")
shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None") shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None"], value=shared.args.model_type or "None")
shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0) shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.') shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.')
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7') shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
@ -127,7 +127,9 @@ def create_ui():
def create_event_handlers(): def create_event_handlers():
shared.gradio['loader'].change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params())) shared.gradio['loader'].change(
loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params())).then(
lambda value: gr.update(choices=loaders.get_model_types(value)), gradio('loader'), gradio('model_type'))
# In this event handler, the interface state is read and updated # In this event handler, the interface state is read and updated
# with the model defaults (if any), and then the model is loaded # with the model defaults (if any), and then the model is loaded

View File

@ -16,7 +16,7 @@ def create_ui(default_preset):
shared.gradio['delete_preset'] = gr.Button('🗑️', elem_classes='refresh-button') shared.gradio['delete_preset'] = gr.Button('🗑️', elem_classes='refresh-button')
with gr.Column(): with gr.Column():
shared.gradio['filter_by_loader'] = gr.Dropdown(label="Filter by loader", choices=["All", "Transformers", "ExLlama_HF", "ExLlama", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp", "llamacpp_HF"], value="All", elem_classes='slim-dropdown') shared.gradio['filter_by_loader'] = gr.Dropdown(label="Filter by loader", choices=["All"] + list(loaders.loaders_and_params.keys()), value="All", elem_classes='slim-dropdown')
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():

View File

@ -40,3 +40,6 @@ https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/text
# GPTQ-for-LLaMa # GPTQ-for-LLaMa
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.0/gptq_for_llama-0.1.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.0/gptq_for_llama-0.1.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.0/gptq_for_llama-0.1.0+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" https://github.com/jllllll/GPTQ-for-LLaMa-CUDA/releases/download/0.1.0/gptq_for_llama-0.1.0+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
# ctransformers
https://github.com/jllllll/ctransformers-cuBLAS-wheels/releases/download/AVX2/ctransformers-0.2.20+cu117-py3-none-any.whl