diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py index 464ab80d3..c87606b4f 100644 --- a/examples/llava/qwen2_vl_surgery.py +++ b/examples/llava/qwen2_vl_surgery.py @@ -88,6 +88,8 @@ def main(args): else: raise ValueError() + local_model = False + model_path = "" model_name = args.model_name print("model_name: ", model_name) qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( @@ -97,8 +99,10 @@ def main(args): vcfg = cfg.vision_config if os.path.isdir(model_name): + local_model = True if model_name.endswith(os.sep): model_name = model_name[:-1] + model_path = model_name model_name = os.path.basename(model_name) fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf" @@ -139,7 +143,10 @@ def main(args): it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`. """ - processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name) + if local_model: + processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path) + else: + processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name) fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue] fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue]