Add --load-in-4bit parameter (#2320)

This commit is contained in:
oobabooga 2023-05-25 01:14:13 -03:00 committed by GitHub
parent 63ce5f9c28
commit 361451ba60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 61 additions and 22 deletions

View File

@ -214,13 +214,22 @@ Optionally, you can use the following command-line flags:
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.|
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
| `--load-in-8bit` | Load the model with 8-bit precision.|
| `--load-in-8bit` | Load the model with 8-bit precision (using bitsandbytes).|
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
| `--xformers` | Use xformer's memory efficient attention. This should increase your tokens/s. |
| `--sdp-attention` | Use torch 2.0's sdp attention. |
| `--trust-remote-code` | Set trust_remote_code=True while loading a model. Necessary for ChatGLM. |
#### Accelerate 4-bit
| Flag | Description |
|---------------------------------------------|-------------|
| `--load-in-4bit` | Load the model with 4-bit precision (using bitsandbytes). |
| `--compute_dtype COMPUTE_DTYPE` | compute dtype for 4-bit. Valid options: bfloat16, float16, float32. |
| `--quant_type QUANT_TYPE` | quant_type for 4-bit. Valid options: nf4, fp4. |
| `--use_double_quant` | use_double_quant for 4-bit. |
#### llama.cpp
| Flag | Description |

View File

@ -149,7 +149,7 @@ def huggingface_loader(model_name):
LoaderClass = AutoModelForCausalLM
# Load the model in simple 16-bit mode by default
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
if torch.has_mps:
device = torch.device('mps')
@ -179,7 +179,21 @@ def huggingface_loader(model_name):
params["torch_dtype"] = torch.float32
else:
params["device_map"] = 'auto'
if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
if shared.args.load_in_4bit:
# See https://github.com/huggingface/transformers/pull/23479/files
# and https://huggingface.co/blog/4bit-transformers-bitsandbytes
quantization_config_params = {
'load_in_4bit': True,
'bnb_4bit_compute_dtype': eval("torch.{}".format(shared.args.compute_dtype)) if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
'bnb_4bit_quant_type': shared.args.quant_type,
'bnb_4bit_use_double_quant': shared.args.use_double_quant,
}
logger.warning("Using the following 4-bit params: " + str(quantization_config_params))
params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
elif shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
elif shared.args.load_in_8bit:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)

View File

@ -114,13 +114,19 @@ parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memo
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
parser.add_argument('--trust-remote-code', action='store_true', help="Set trust_remote_code=True while loading a model. Necessary for ChatGLM.")
# Accelerate 4-bit
parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision (using bitsandbytes).')
parser.add_argument('--compute_dtype', type=str, default="bfloat16", help="compute dtype for 4-bit. Valid options: bfloat16, float16, float32.")
parser.add_argument('--quant_type', type=str, default="nf4", help='quant_type for 4-bit. Valid options: nf4, fp4.')
parser.add_argument('--use_double_quant', action='store_true', help='use_double_quant for 4-bit.')
# llama.cpp
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
parser.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.')

View File

@ -30,9 +30,10 @@ theme = gr.themes.Default(
def list_model_elements():
elements = ['cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'wbits', 'groupsize', 'model_type', 'pre_layer', 'threads', 'n_batch', 'no_mmap', 'mlock', 'n_gpu_layers']
elements = ['cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'load_in_4bit', 'compute_dtype', 'quant_type', 'use_double_quant', 'wbits', 'groupsize', 'model_type', 'pre_layer', 'threads', 'n_batch', 'no_mmap', 'mlock', 'n_gpu_layers']
for i in range(torch.cuda.device_count()):
elements.append(f'gpu_memory_{i}')
return elements

View File

@ -1,4 +1,3 @@
accelerate==0.19.0
colorama
datasets
flexgen==0.1.7
@ -12,9 +11,10 @@ pyyaml
requests
safetensors==0.3.1
sentencepiece
transformers==4.29.2
tqdm
git+https://github.com/huggingface/peft@4fd374e80d670781c0d82c96ce94d1215ff23306
bitsandbytes==0.38.1; platform_system != "Windows"
git+https://github.com/huggingface/peft@3714aa2fff158fdfa637b2b65952580801d890b2
git+https://github.com/huggingface/transformers@e45e756d22206ca8fa9fb057c8c3d8fa79bf81c6
git+https://github.com/huggingface/accelerate@0226f750257b3bf2cadc4f189f9eef0c764a0467
bitsandbytes==0.39.0; platform_system != "Windows"
llama-cpp-python==0.1.53; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.53/llama_cpp_python-0.1.53-cp310-cp310-win_amd64.whl; platform_system == "Windows"

View File

@ -353,11 +353,12 @@ def create_model_menus():
with gr.Row():
with gr.Column():
with gr.Box():
gr.Markdown('Transformers parameters')
gr.Markdown('Transformers')
with gr.Row():
with gr.Column():
for i in range(len(total_mem)):
shared.gradio[f'gpu_memory_{i}'] = gr.Slider(label=f"gpu-memory in MiB for device :{i}", maximum=total_mem[i], value=default_gpu_mem[i])
shared.gradio['cpu_memory'] = gr.Slider(label="cpu-memory in MiB", maximum=total_cpu_mem, value=default_cpu_mem)
with gr.Column():
@ -367,9 +368,26 @@ def create_model_menus():
shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
with gr.Box():
gr.Markdown('Transformers 4-bit')
with gr.Row():
with gr.Column():
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant)
with gr.Column():
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype)
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type)
with gr.Row():
shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown.')
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main")
shared.gradio['download_model_button'] = gr.Button("Download")
with gr.Column():
with gr.Box():
gr.Markdown('GPTQ parameters')
gr.Markdown('GPTQ')
with gr.Row():
with gr.Column():
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
@ -379,17 +397,8 @@ def create_model_menus():
shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
with gr.Row():
with gr.Column():
with gr.Row():
shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown.')
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main")
shared.gradio['download_model_button'] = gr.Button("Download")
with gr.Column():
with gr.Box():
gr.Markdown('llama.cpp parameters')
gr.Markdown('llama.cpp')
with gr.Row():
with gr.Column():
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=32, value=shared.args.threads)
@ -978,7 +987,7 @@ def create_interface():
shared.gradio['interface'].load(lambda: None, None, None, _js=f"() => {{{js}}}")
if shared.settings['dark_theme']:
shared.gradio['interface'].load(lambda: None, None, None, _js=f"() => document.getElementsByTagName('body')[0].classList.add('dark')")
shared.gradio['interface'].load(lambda: None, None, None, _js="() => document.getElementsByTagName('body')[0].classList.add('dark')")
shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False)