Added Falcon LoRA training support (#2684)

I am 50% sure this will work
This commit is contained in:
MikoAL 2023-06-20 12:03:44 +08:00 committed by GitHub
parent c623e142ac
commit c40932eb39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -30,12 +30,14 @@ try:
     MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
 except:
     standard_modules = ["q_proj", "v_proj"]
-    model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"]}
+    model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"], "rw":["query_key_value"]}
     MODEL_CLASSES = {
         "LlamaForCausalLM": "llama",
         "OPTForCausalLM": "opt",
         "GPTJForCausalLM": "gptj",
-        "GPTNeoXForCausalLM": "gpt_neox"
+        "GPTNeoXForCausalLM": "gpt_neox",
+        "RWForCausalLM": "rw"
     }
     train_log = {}