Remove flexgen support

oobabooga 2023-07-25 15:15:29 -07:00
parent 5134d5b1c6
commit 75c2dd38cf
8 changed files with 3 additions and 233 deletions

View File

@@ -1,63 +0,0 @@
'''
Converts a transformers model to a format compatible with flexgen.
'''

import argparse
import os
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
args = parser.parse_args()


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    global torch_linear_init_backup
    global torch_layer_norm_init_backup

    torch_linear_init_backup = torch.nn.Linear.reset_parameters
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)

    torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def restore_torch_init():
    """Rollback the change made by disable_torch_init."""
    import torch

    setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
    setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)


if __name__ == '__main__':
    path = Path(args.MODEL)
    model_name = path.name

    print(f"Loading {model_name}...")
    # disable_torch_init()
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    # restore_torch_init()

    tokenizer = AutoTokenizer.from_pretrained(path)
    out_folder = Path(f"models/{model_name}-np")
    if not Path(out_folder).exists():
        os.mkdir(out_folder)

    print(f"Saving the converted model to {out_folder}...")
    for name, param in tqdm(list(model.model.named_parameters())):
        name = name.replace("decoder.final_layer_norm", "decoder.layer_norm")
        param_path = os.path.join(out_folder, name)
        with open(param_path, "wb") as f:
            np.save(f, param.cpu().detach().numpy())

View File

@@ -1,64 +0,0 @@
> FlexGen is a high-throughput generation engine for running large language models with limited GPU memory (e.g., a 16GB T4 GPU or a 24GB RTX3090 gaming card!).
https://github.com/FMInference/FlexGen
## Installation
No additional installation steps are necessary. FlexGen is in the `requirements.txt` file for this project.
## Converting a model
FlexGen only works with OPT models, and the model needs to be converted to NumPy format before starting the web UI:
```
python convert-to-flexgen.py models/opt-1.3b/
```
The output will be saved to `models/opt-1.3b-np/`.
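Each tensor is written as a separate NumPy file named after the corresponding parameter (see the conversion script above), so one quick way to sanity-check the conversion is simply to list the output folder; the exact file names depend on the model, so treat this as an illustrative check rather than part of the conversion procedure:
```
ls models/opt-1.3b-np/
```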
## Usage
The basic command is the following:
```
python server.py --model opt-1.3b --loader flexgen
```
For large models, the RAM usage may be too high and your computer may freeze. If that happens, you can try this:
```
python server.py --model opt-1.3b --loader flexgen --compress-weight
```
With this second command, I was able to run both OPT-6.7b and OPT-13B with **2GB VRAM**, and the speed was good in both cases.
You can also manually set the offload strategy with
```
python server.py --model opt-1.3b --loader flexgen --percent 0 100 100 0 100 0
```
where the six numbers after `--percent` are:
```
the percentage of weight on GPU
the percentage of weight on CPU
the percentage of attention cache on GPU
the percentage of attention cache on CPU
the percentage of activations on GPU
the percentage of activations on CPU
```
You should typically only change the first two numbers. If their sum is less than 100, the remaining layers will be offloaded to disk, by default into the `text-generation-webui/cache` folder.
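For example (an illustrative split, not a benchmarked one), the following puts 30% of the weights on the GPU and 50% on the CPU, leaving the remaining 20% to be offloaded to disk:
```
python server.py --model opt-1.3b --loader flexgen --percent 30 50 100 0 100 0
```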
## Performance
In my experiments with OPT-30B using an RTX 3090 on Linux, I obtained these results:
* `--loader flexgen --compress-weight --percent 0 100 100 0 100 0`: 0.99 seconds per token.
* `--loader flexgen --compress-weight --percent 100 0 100 0 100 0`: 0.765 seconds per token.
## Limitations
* Only works with OPT models.
* Only two generation parameters are available: `temperature` and `do_sample`.

View File

@@ -56,7 +56,6 @@ def load_model(model_name, loader=None):
        'GPTQ-for-LLaMa': GPTQ_loader,
        'llama.cpp': llamacpp_loader,
        'llamacpp_HF': llamacpp_HF_loader,
-        'FlexGen': flexgen_loader,
        'RWKV': RWKV_loader,
        'ExLlama': ExLlama_loader,
        'ExLlama_HF': ExLlama_HF_loader

@@ -221,32 +220,6 @@ def huggingface_loader(model_name):
    return model

-def flexgen_loader(model_name):
-    from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
-
-    # Initialize environment
-    env = ExecutionEnv.create(shared.args.disk_cache_dir)
-
-    # Offloading policy
-    policy = Policy(1, 1,
-                    shared.args.percent[0], shared.args.percent[1],
-                    shared.args.percent[2], shared.args.percent[3],
-                    shared.args.percent[4], shared.args.percent[5],
-                    overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
-                    cpu_cache_compute=False, attn_sparsity=1.0,
-                    compress_weight=shared.args.compress_weight,
-                    comp_weight_config=CompressionConfig(
-                        num_bits=4, group_size=64,
-                        group_dim=0, symmetric=False),
-                    compress_cache=False,
-                    comp_cache_config=CompressionConfig(
-                        num_bits=4, group_size=64,
-                        group_dim=2, symmetric=False))
-
-    model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
-    return model

def RWKV_loader(model_name):
    from modules.RWKV import RWKVModel, RWKVTokenizer

View File

@@ -30,8 +30,6 @@ def infer_loader(model_name):
        loader = 'llama.cpp'
    elif re.match('.*rwkv.*\.pth', model_name.lower()):
        loader = 'RWKV'
-    elif shared.args.flexgen:
-        loader = 'FlexGen'
    else:
        loader = 'Transformers'

View File

@@ -95,7 +95,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')

# Model loader
-parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv, flexgen')
+parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv')

# Accelerate/transformers
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
@@ -156,7 +156,6 @@ parser.add_argument('--gpu-split', type=str, help="Comma-separated list of VRAM
parser.add_argument('--max_seq_len', type=int, default=2048, help="Maximum sequence length.")

# FlexGen
-parser.add_argument('--flexgen', action='store_true', help='DEPRECATED')
parser.add_argument('--percent', type=int, nargs="+", default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).')
parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.")
parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, default=True, help="FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%%).")
@@ -202,9 +201,6 @@ if args.autogptq:
if args.gptq_for_llama:
    logger.warning('--gptq-for-llama has been deprecated and will be removed soon. Use --loader gptq-for-llama instead.')
    args.loader = 'gptq-for-llama'

-if args.flexgen:
-    logger.warning('--flexgen has been deprecated and will be removed soon. Use --loader flexgen instead.')
-    args.loader = 'FlexGen'

# Security warnings
if args.trust_remote_code:

View File

@@ -53,8 +53,6 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
        return input_ids
-    elif shared.args.flexgen:
-        return input_ids.numpy()
    elif shared.args.deepspeed:
        return input_ids.to(device=local_rank)
    elif torch.backends.mps.is_available():
@@ -182,8 +180,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
        generate_func = generate_reply_custom
-    elif shared.args.flexgen:
-        generate_func = generate_reply_flexgen
    else:
        generate_func = generate_reply_HF
@@ -339,66 +335,3 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
        new_tokens = len(encode(original_question + reply)[0]) - original_tokens
        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
        return

-def generate_reply_flexgen(question, original_question, seed, state, stopping_strings=None, is_chat=False):
-    generate_params = {}
-    for k in ['max_new_tokens', 'do_sample', 'temperature']:
-        generate_params[k] = state[k]
-
-    if state['stream']:
-        generate_params['max_new_tokens'] = 8
-
-    # Encode the input
-    input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
-    output = input_ids[0]
-
-    # Find the eos tokens
-    eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
-    if not state['ban_eos_token']:
-        generate_params['stop'] = eos_token_ids[-1]
-
-    # Add the encoded tokens to generate_params
-    question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
-    original_input_ids = input_ids
-    generate_params.update({'inputs': input_ids})
-    if inputs_embeds is not None:
-        generate_params.update({'inputs_embeds': inputs_embeds})
-
-    t0 = time.time()
-    try:
-        if not is_chat:
-            yield ''
-
-        # Generate the entire reply at once.
-        if not state['stream']:
-            with torch.no_grad():
-                output = shared.model.generate(**generate_params)[0]
-
-            yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
-
-        # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
-        else:
-            for i in range(state['max_new_tokens'] // 8 + 1):
-                if shared.stop_everything:
-                    break
-
-                clear_torch_cache()
-                with torch.no_grad():
-                    output = shared.model.generate(**generate_params)[0]
-
-                if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
-                    break
-
-                yield get_reply_from_output_ids(output, original_input_ids, original_question, state)
-                input_ids = np.reshape(output, (1, output.shape[0]))
-                generate_params.update({'inputs': input_ids})
-
-    except Exception:
-        traceback.print_exc()
-    finally:
-        t1 = time.time()
-        original_tokens = len(original_input_ids[0])
-        new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0)
-        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
-        return

View File

@@ -71,10 +71,7 @@ def natural_keys(text):
def get_available_models():
-    if shared.args.flexgen:
-        return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=natural_keys)
-    else:
-        return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json', '.yaml'))], key=natural_keys)
+    return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json', '.yaml'))], key=natural_keys)


def get_available_presets():

View File

@@ -321,7 +321,7 @@ def create_settings_menus(default_preset):
    with gr.Row():
        with gr.Column():
            with gr.Row():
-                shared.gradio['preset_menu'] = gr.Dropdown(choices=utils.get_available_presets(), value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset', elem_classes='slim-dropdown')
+                shared.gradio['preset_menu'] = gr.Dropdown(choices=utils.get_available_presets(), value=default_preset, label='Generation parameters preset', elem_classes='slim-dropdown')
                ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': utils.get_available_presets()}, 'refresh-button')
                shared.gradio['save_preset'] = gr.Button('💾', elem_classes='refresh-button')
                shared.gradio['delete_preset'] = gr.Button('🗑️', elem_classes='refresh-button')