import sys import os sys.path.insert(0, os.path.dirname(__file__)) from embd_input import MyModel import numpy as np from torch import nn import torch from PIL import Image minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4") sys.path.insert(0, minigpt4_path) from minigpt4.models.blip2 import Blip2Base from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor class MiniGPT4(Blip2Base): """ MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4 """ def __init__(self, args, vit_model="eva_clip_g", q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", img_size=224, drop_path_rate=0, use_grad_checkpoint=False, vit_precision="fp32", freeze_vit=True, freeze_qformer=True, num_query_token=32, llama_model="", prompt_path="", prompt_template="", max_txt_len=32, end_sym='\n', low_resource=False, # use 8 bit and put vit in cpu device_8bit=0 ): super().__init__() self.img_size = img_size self.low_resource = low_resource self.preprocessor = Blip2ImageEvalProcessor(img_size) print('Loading VIT') self.visual_encoder, self.ln_vision = self.init_vision_encoder( vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision ) print('Loading VIT Done') print('Loading Q-Former') self.Qformer, self.query_tokens = self.init_Qformer( num_query_token, self.visual_encoder.num_features ) self.Qformer.cls = None self.Qformer.bert.embeddings.word_embeddings = None self.Qformer.bert.embeddings.position_embeddings = None for layer in self.Qformer.bert.encoder.layer: layer.output = None layer.intermediate = None self.load_from_pretrained(url_or_filename=q_former_model) print('Loading Q-Former Done') self.llama_proj = nn.Linear( self.Qformer.config.hidden_size, 5120 # self.llama_model.config.hidden_size ) self.max_txt_len = max_txt_len self.end_sym = end_sym self.model = MyModel(["main", *args]) # system promt self.model.eval_string("Give the following image: ImageContent. " "You will be able to see the image once I provide it to you. Please answer my questions." "###") def encode_img(self, image): image = self.preprocessor(image) image = image.unsqueeze(0) device = image.device if self.low_resource: self.vit_to_cpu() image = image.to("cpu") with self.maybe_autocast(): image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_output = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts, return_dict=True, ) inputs_llama = self.llama_proj(query_output.last_hidden_state) # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) return inputs_llama def load_projection(self, path): state = torch.load(path)["model"] self.llama_proj.load_state_dict({ "weight": state["llama_proj.weight"], "bias": state["llama_proj.bias"]}) def chat(self, question): self.model.eval_string("Human: ") self.model.eval_string(question) self.model.eval_string("\n### Assistant:") return self.model.generate_with_print(end="###") def chat_with_image(self, image, question): with torch.no_grad(): embd_image = self.encode_img(image) embd_image = embd_image.cpu().numpy()[0] self.model.eval_string("Human: ") self.model.eval_float(embd_image.T) self.model.eval_string(" ") self.model.eval_string(question) self.model.eval_string("\n### Assistant:") return self.model.generate_with_print(end="###") if __name__=="__main__": a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"]) a.load_projection(os.path.join( os.path.dirname(__file__) , "pretrained_minigpt4.pth")) respose = a.chat_with_image( Image.open("./media/llama1-logo.png").convert('RGB'), "what is the text in the picture?") a.chat("what is the color of it?")