diff --git a/modules/llama_attn_hijack.py b/modules/llama_attn_hijack.py
index 925cdaa3..e4d1ceab 100644
--- a/modules/llama_attn_hijack.py
+++ b/modules/llama_attn_hijack.py
@@ -17,6 +17,7 @@ if shared.args.xformers:
 
 
 def hijack_llama_attention():
+    import transformers.models.llama.modeling_llama
     if shared.args.xformers:
         transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
         logger.info("Replaced attention with xformers_attention")