diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index 7ee604ee..14e9b641 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -2,13 +2,11 @@ import base64 from io import BytesIO import gradio as gr +import torch +from transformers import BlipForConditionalGeneration, BlipProcessor import modules.chat as chat import modules.shared as shared -from modules.bot_picture import caption_image - -params = { -} # If 'state' is True, will hijack the next chat generation with # custom input text @@ -17,6 +15,14 @@ input_hijack = { 'value': ["", ""] } +processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") +model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu") + +def caption_image(raw_image): + inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32) + out = model.generate(**inputs, max_new_tokens=100) + return processor.decode(out[0], skip_special_tokens=True) + def generate_chat_picture(picture, name1, name2): text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*' buffer = BytesIO() diff --git a/modules/bot_picture.py b/modules/bot_picture.py deleted file mode 100644 index dd4d73eb..00000000 --- a/modules/bot_picture.py +++ /dev/null @@ -1,10 +0,0 @@ -import torch -from transformers import BlipForConditionalGeneration, BlipProcessor - -processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") -model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu") - -def caption_image(raw_image): - inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32) - out = model.generate(**inputs, max_new_tokens=100) - return processor.decode(out[0], skip_special_tokens=True) diff --git a/modules/extensions.py b/modules/extensions.py index c0da496a..c8de8a7b 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -33,10 +33,11 @@ def apply_extensions(text, typ): def create_extensions_block(): # Updating the default values for extension, name in iterator(): - for param in extension.params: - _id = f"{name}-{param}" - if _id in shared.settings: - extension.params[param] = shared.settings[_id] + if hasattr(extension, 'params'): + for param in extension.params: + _id = f"{name}-{param}" + if _id in shared.settings: + extension.params[param] = shared.settings[_id] # Creating the extension ui elements for extension, name in iterator():