# Mirror of https://github.com/ggerganov/llama.cpp.git
# (synced 2024-12-27 06:39:25 +01:00).
# 71 lines, 2.7 KiB, Python.
|
import sys
|
||
|
import os
|
||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||
|
from embd_input import MyModel
|
||
|
import numpy as np
|
||
|
from torch import nn
|
||
|
import torch
|
||
|
from transformers import CLIPVisionModel, CLIPImageProcessor
|
||
|
from PIL import Image
|
||
|
|
||
|
# Model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1'.
# Hugging Face id used to load both the CLIP image processor and vision model.
vision_tower = "openai/clip-vit-large-patch14"
# Index into the vision model's `hidden_states` output; -2 picks the
# second-to-last entry.
select_hidden_state_layer = -2
# Number of image tokens:
# (vision_config.image_size // vision_config.patch_size) ** 2
_IMAGE_SIZE = 224
_PATCH_SIZE = 14
image_token_len = (_IMAGE_SIZE // _PATCH_SIZE) ** 2  # == 256
|
||
|
|
||
|
class Llava:
    """Chat wrapper combining a CLIP vision model, a linear projector and MyModel.

    Text is fed to the wrapped ``MyModel`` with ``eval_string``. Images are
    encoded by the CLIP vision model, passed through ``mm_projector`` and fed
    to ``MyModel`` as float embeddings between image start/end marker tokens.
    """

    # Special token ids, kept as the same offset expressions the script has
    # always used (labels taken from the original inline comments).
    _IM_PATCH_TOKEN = 32003 - 3  # im_patch
    _IM_START_TOKEN = 32003 - 2  # im_start
    _IM_END_TOKEN = 32003 - 1  # im_end

    def __init__(self, args):
        """Load the CLIP processor and vision model and create the language model.

        Args:
            args: Command-line style argument list; it is forwarded to
                ``MyModel`` after a leading ``"main"`` entry.
        """
        self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
        self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
        # Linear map from 1024 input features to 5120 output features. Its
        # weights are randomly initialised until load_projection() is called.
        self.mm_projector = nn.Linear(1024, 5120)
        self.model = MyModel(["main", *args])

    def load_projection(self, path):
        """Load the projector weight and bias from a checkpoint.

        Args:
            path: Checkpoint readable by ``torch.load`` that contains the keys
                ``"model.mm_projector.weight"`` and
                ``"model.mm_projector.bias"``. Other keys are ignored.

        Raises:
            KeyError: If either projector key is missing from the checkpoint.
        """
        # map_location="cpu" lets a checkpoint that was saved with GPU tensors
        # load on a machine without CUDA.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(path, map_location="cpu")
        self.mm_projector.load_state_dict({
            "weight": state["model.mm_projector.weight"],
            "bias": state["model.mm_projector.bias"],
        })

    def _answer(self, question):
        """Feed the question and the assistant prompt, then generate the reply."""
        self.model.eval_string(question)
        self.model.eval_string("\nassistant: ")
        return self.model.generate_with_print()

    def chat(self, question):
        """Run one text-only user turn.

        Args:
            question: Text of the user turn.

        Returns:
            Whatever ``MyModel.generate_with_print`` returns.
        """
        self.model.eval_string("user: ")
        return self._answer(question)

    def chat_with_image(self, image, question):
        """Run one user turn consisting of an image followed by a question.

        Args:
            image: Image accepted by ``CLIPImageProcessor.preprocess`` (the
                script passes an RGB ``PIL.Image``).
            question: Text of the user turn, fed after the image.

        Returns:
            Whatever ``MyModel.generate_with_print`` returns.
        """
        with torch.no_grad():
            pixel_values = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            # The vision model expects a batch, so re-add the leading axis.
            image_forward_out = self.vision_tower(
                pixel_values.unsqueeze(0), output_hidden_states=True)
            hidden = image_forward_out.hidden_states[select_hidden_state_layer]
            # Drop the first position along axis 1 (presumably the CLIP class
            # token -- confirm against the CLIP model documentation).
            image_feature = hidden[:, 1:]
            embd_image = self.mm_projector(image_feature)
            # Take the single batch element as a NumPy array.
            embd_image = embd_image.cpu().numpy()[0]

        self.model.eval_string("user: ")
        self.model.eval_token(self._IM_START_TOKEN)
        # The embedding matrix is handed over transposed.
        self.model.eval_float(embd_image.T)
        # Pad with patch tokens when fewer than image_token_len rows were
        # produced; no tokens are added otherwise.
        for _ in range(image_token_len - embd_image.shape[0]):
            self.model.eval_token(self._IM_PATCH_TOKEN)
        self.model.eval_token(self._IM_END_TOKEN)
        return self._answer(question)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Model from liuhaotian/LLaVA-13b-delta-v1-1.
    llava = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
    # The projection weights are extracted from
    # https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin.
    # pytorch_model-00003-of-00003.bin can also be used here directly.
    # NOTE: the file name below is spelled "projetion"; it must match the
    # name of the file placed next to this script.
    llava.load_projection(os.path.join(
        os.path.dirname(__file__),
        "llava_projetion.pth"))
    # The return values are not used here; generate_with_print presumably
    # prints the reply as it is generated -- confirm in embd_input.
    llava.chat_with_image(
        Image.open("./media/llama1-logo.png").convert('RGB'),
        "what is the text in the picture?")
    llava.chat("what is the color of it?")
|
||
|
|
||
|
|
||
|
|