multimodal : add BakLLaVA conversion support (#3682)

Author: M. Yusuf Sarıgöz, 2023-10-19 19:40:41 +03:00 (committed by GitHub)
parent 60abea9798
commit f3b25e4043

@@ -16,13 +16,29 @@ checkpoint = torch.load(path)
 mm_tensors = [k for k, v in checkpoint.items() if k.startswith("model.mm_projector")]
 
 # store these tensors in a new dictionary and torch.save them
-projector = {name: checkpoint[name] for name in mm_tensors}
+projector = {name: checkpoint[name].float() for name in mm_tensors}
 torch.save(projector, f"{args.model}/llava.projector")
 
 # remove these tensors from the checkpoint and save it again
 for name in mm_tensors:
     del checkpoint[name]
 
+# BakLLaVA models also contain CLIP tensors in the checkpoint
+clip_tensors = [k for k, v in checkpoint.items() if k.startswith("model.vision_tower")]
+if len(clip_tensors) > 0:
+    clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors}
+    torch.save(clip, f"{args.model}/llava.clip")
+
+    # remove these tensors from the checkpoint as well
+    for name in clip_tensors:
+        del checkpoint[name]
+
+    # added tokens should be removed so that Mistral-based models can be converted
+    if os.path.exists(f"{args.model}/added_tokens.json"):
+        with open(f"{args.model}/added_tokens.json", "w") as f:
+            f.write("{}\n")
+
 torch.save(checkpoint, path)
 print("Done!")