From c6fe1ced0175020c158a2e72f2c1a21818446308 Mon Sep 17 00:00:00 2001 From: Forkoz <59298527+Ph0rk0z@users.noreply.github.com> Date: Sun, 16 Apr 2023 22:15:03 +0000 Subject: [PATCH] Add ChatGLM support (#1256) --------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com> --- README.md | 1 + characters/instruction-following/ChatGLM.yaml | 3 +++ download-model.py | 2 +- models/config.yaml | 3 +++ modules/chat.py | 6 ++++-- modules/models.py | 21 ++++++++++++------- modules/shared.py | 5 +++++ 7 files changed, 31 insertions(+), 10 deletions(-) create mode 100644 characters/instruction-following/ChatGLM.yaml diff --git a/README.md b/README.md index 32c86431..5b8c1d2e 100644 --- a/README.md +++ b/README.md @@ -219,6 +219,7 @@ Optionally, you can use the following command-line flags: | `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. | | `--xformers` | Use xformer's memory efficient attention. This should increase your tokens/s. | | `--sdp-attention` | Use torch 2.0's sdp attention. | +| `--trust-remote-code` | Set trust_remote_code=True while loading a model. Necessary for ChatGLM. | #### llama.cpp diff --git a/characters/instruction-following/ChatGLM.yaml b/characters/instruction-following/ChatGLM.yaml new file mode 100644 index 00000000..02a26855 --- /dev/null +++ b/characters/instruction-following/ChatGLM.yaml @@ -0,0 +1,3 @@ +name: "答:" +your_name: "[Round <|round|>]\n问:" +context: "" diff --git a/download-model.py b/download-model.py index fc17e716..01ef2753 100644 --- a/download-model.py +++ b/download-model.py @@ -108,7 +108,7 @@ def get_download_links_from_huggingface(model, branch, text_only=False): is_safetensors = re.match(".*\.safetensors", fname) is_pt = re.match(".*\.pt", fname) is_ggml = re.match("ggml.*\.bin", fname) - is_tokenizer = re.match("tokenizer.*\.model", fname) + is_tokenizer = re.match("(tokenizer|ice).*\.model", fname) is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)): diff --git a/models/config.yaml b/models/config.yaml index 2cb09ed0..3ebf21f8 100644 --- a/models/config.yaml +++ b/models/config.yaml @@ -45,3 +45,6 @@ llama-[0-9]*b-4bit$: .*koala: mode: 'instruct' instruction_template: 'Koala' +.*chatglm: + mode: 'instruct' + instruction_template: 'ChatGLM' diff --git a/modules/chat.py b/modules/chat.py index 57acb898..e3c4e00f 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -49,7 +49,8 @@ def generate_chat_prompt(user_input, state, **kwargs): string = shared.history['internal'][i][0] if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: - rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n") + this_prefix1 = prefix1.replace('<|round|>', f'{i}') # for ChatGLM + rows.insert(1, f"{this_prefix1}{string.strip()}{state['end_of_turn']}\n") i -= 1 @@ -60,7 +61,8 @@ def generate_chat_prompt(user_input, state, **kwargs): # Adding the user message if len(user_input) > 0: - rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n") + this_prefix1 = prefix1.replace('<|round|>', f'{len(shared.history["internal"])}') # for ChatGLM + rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n") # Adding the Character prefix rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix")) diff --git a/modules/models.py b/modules/models.py index 2a9007e0..ca9498d2 100644 --- a/modules/models.py +++ b/modules/models.py @@ -10,8 +10,8 @@ import numpy as np import torch import transformers from accelerate import infer_auto_device_map, init_empty_weights -from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - BitsAndBytesConfig, LlamaTokenizer) +from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, + AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer) import modules.shared as shared from modules import llama_attn_hijack @@ -44,10 +44,16 @@ def load_model(model_name): shared.is_RWKV = 'rwkv-' in model_name.lower() shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0 + if 'chatglm' in model_name.lower(): + LoaderClass = AutoModel + trust_remote_code = shared.args.trust_remote_code + else: + LoaderClass = AutoModelForCausalLM + trust_remote_code = False # Load the model in simple 16-bit mode by default if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]): - model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) + model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code) if torch.has_mps: device = torch.device('mps') model = model.to(device) @@ -79,7 +85,7 @@ def load_model(model_name): # DeepSpeed ZeRO-3 elif shared.args.deepspeed: - model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) + model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] model.module.eval() # Inference print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}") @@ -120,6 +126,7 @@ def load_model(model_name): params["torch_dtype"] = torch.float32 else: params["device_map"] = 'auto' + params["trust_remote_code"] = trust_remote_code if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)): params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True) elif shared.args.load_in_8bit: @@ -156,7 +163,7 @@ def load_model(model_name): if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': config = AutoConfig.from_pretrained(checkpoint) with init_empty_weights(): - model = AutoModelForCausalLM.from_config(config) + model = LoaderClass.from_config(config) model.tie_weights() params['device_map'] = infer_auto_device_map( model, @@ -165,7 +172,7 @@ def load_model(model_name): no_split_module_classes=model._no_split_modules ) - model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) + model = LoaderClass.from_pretrained(checkpoint, **params) # Hijack attention with xformers if any((shared.args.xformers, shared.args.sdp_attention)): @@ -185,7 +192,7 @@ def load_model(model_name): except: pass else: - tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/")) + tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), trust_remote_code=trust_remote_code) print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") return model, tokenizer diff --git a/modules/shared.py b/modules/shared.py index 4a113502..0294073c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -113,6 +113,7 @@ parser.add_argument('--bf16', action='store_true', help='Load the model with bfl parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.') parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.") parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.") +parser.add_argument('--trust-remote-code', action='store_true', help="Set trust_remote_code=True while loading a model. Necessary for ChatGLM.") # llama.cpp parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.') @@ -162,6 +163,10 @@ if args.cai_chat: print("Warning: --cai-chat is deprecated. Use --chat instead.") args.chat = True +# Security warnings +if args.trust_remote_code: + print("Warning: trust_remote_code is enabled. This is dangerous.") + def is_chat(): return args.chat