Move bot_picture.py inside the extension

oobabooga 2023-02-25 03:00:19 -03:00
parent 5ac24b019e
commit 91f5852245
3 changed files with 15 additions and 18 deletions


@@ -2,13 +2,11 @@ import base64
 from io import BytesIO
 
 import gradio as gr
+import torch
+from transformers import BlipForConditionalGeneration, BlipProcessor
 
 import modules.chat as chat
 import modules.shared as shared
-from modules.bot_picture import caption_image
-
-params = {
-}
 
 # If 'state' is True, will hijack the next chat generation with
 # custom input text
@@ -17,6 +15,14 @@ input_hijack = {
     'value': ["", ""]
 }
 
+processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
+
+def caption_image(raw_image):
+    inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
+    out = model.generate(**inputs, max_new_tokens=100)
+    return processor.decode(out[0], skip_special_tokens=True)
+
 def generate_chat_picture(picture, name1, name2):
     text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
     buffer = BytesIO()
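
Standalone, the captioning path that now lives inside the extension can be exercised without the web UI. A minimal sketch mirroring the code added above: it loads the same Salesforce/blip-image-captioning-base checkpoint on CPU, captions a local image, and prints the same message that generate_chat_picture builds for the chat hijack. The image path and character names are placeholders for this sketch, not part of the commit.

import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

# Same checkpoint and CPU/float32 setup as the extension script above
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")

def caption_image(raw_image):
    # Preprocess to pixel tensors, generate a short caption, decode it to text
    inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
    out = model.generate(**inputs, max_new_tokens=100)
    return processor.decode(out[0], skip_special_tokens=True)

if __name__ == '__main__':
    # "example.jpg", "You" and "Bot" are placeholder inputs for this sketch
    picture = Image.open("example.jpg")
    name1, name2 = "You", "Bot"
    # This is the string the extension injects as the next chat input
    print(f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*')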


@@ -1,10 +0,0 @@
-import torch
-from transformers import BlipForConditionalGeneration, BlipProcessor
-
-processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
-
-def caption_image(raw_image):
-    inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
-    out = model.generate(**inputs, max_new_tokens=100)
-    return processor.decode(out[0], skip_special_tokens=True)
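
With modules/bot_picture.py deleted, caption_image can no longer be imported from modules; anything outside the extension that still wants a caption would have to import the extension's script module instead. A hypothetical sketch, where the extensions.send_pictures.script path is an assumption about the repository layout rather than something stated in this diff:

from PIL import Image
# Hypothetical import path -- the diff does not name the extension folder
from extensions.send_pictures import script as send_pictures

image = Image.open("example.jpg")  # placeholder image path
print(send_pictures.caption_image(image))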


@@ -33,6 +33,7 @@ def apply_extensions(text, typ):
 def create_extensions_block():
     # Updating the default values
     for extension, name in iterator():
-        for param in extension.params:
-            _id = f"{name}-{param}"
-            if _id in shared.settings:
+        if hasattr(extension, 'params'):
+            for param in extension.params:
+                _id = f"{name}-{param}"
+                if _id in shared.settings:
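
The new hasattr guard means create_extensions_block only syncs default values for extensions that actually define a params dict, which matters now that the picture extension above no longer declares one. A minimal sketch of the guard's effect, using stand-in extension objects and a plain settings dict instead of the real iterator() and modules.shared (the names update_defaults, with_params and without_params are made up for this example):

from types import SimpleNamespace

# Stand-ins: one extension that exposes params, one that does not
with_params = SimpleNamespace(params={'max_new_tokens': 50})
without_params = SimpleNamespace()
settings = {'demo-max_new_tokens': 200}

def update_defaults(extensions):
    for extension, name in extensions:
        if hasattr(extension, 'params'):  # the guard added in this commit
            for param in extension.params:
                _id = f"{name}-{param}"
                if _id in settings:
                    extension.params[param] = settings[_id]

update_defaults([(with_params, 'demo'), (without_params, 'pictures')])
print(with_params.params)  # {'max_new_tokens': 200}
# without_params is simply skipped instead of raising AttributeError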