Merge pull request #76 from SillyLossy/main

Use BLIP to send a picture to model
oobabooga 2023-02-14 23:57:44 -03:00 committed by GitHub
commit d4d90a8000
3 changed files with 58 additions and 9 deletions

modules/bot_picture.py (new file)

@@ -0,0 +1,14 @@
import requests
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration
from transformers import BlipProcessor

# Load the BLIP image-captioning model once at import time, in fp16 on the GPU.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16).to("cuda")

def caption_image(raw_image):
    # Generate a short natural-language caption for a PIL image.
    inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
    out = model.generate(**inputs, max_new_tokens=100)
    return processor.decode(out[0], skip_special_tokens=True)
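
For reference, a minimal sketch of exercising this module on its own (the image path is a placeholder, and a CUDA GPU is assumed since the model is moved to "cuda" at import time):

# Minimal usage sketch; 'example.jpg' is a placeholder path.
from PIL import Image
from modules.bot_picture import caption_image

raw_image = Image.open('example.jpg').convert('RGB')
print(caption_image(raw_image))  # e.g. 'a dog sitting on a couch'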

modules/html_generator.py

@@ -217,6 +217,12 @@ def generate_chat_html(history, name1, name2, character):
.body {
}

.body img {
  max-width: 300px;
  max-height: 300px;
  border-radius: 20px;
}
"""
output = ''

server.py

@@ -23,6 +23,7 @@ from tqdm import tqdm
from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from io import BytesIO

from modules.html_generator import *
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
@@ -53,6 +54,7 @@ parser.add_argument('--listen', action='store_true', help='Make the web UI reach
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes.')
args = parser.parse_args()

if (args.chat or args.cai_chat) and not args.no_stream:
@@ -97,6 +99,9 @@ if args.deepspeed:
    ds_config = generate_ds_config(args.bf16, 1 * world_size, args.nvme_offload_dir)
    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration

if args.picture and (args.cai_chat or args.chat):
    import modules.bot_picture as bot_picture

def load_model(model_name):
    print(f"Loading {model_name}...")
    t0 = time.time()
@@ -561,8 +566,12 @@ def extract_message_from_reply(question, reply, current, other, check, extension
    return reply, next_character_found, substring_found

def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None):
    if args.picture and picture is not None:
        text, visible_text = generate_chat_picture(picture, name1, name2)
    else:
        visible_text = text
    text = apply_extensions(text, "input")
    question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
    history['internal'].append(['', ''])
@@ -571,14 +580,14 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p,
    for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
        reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True)
        history['internal'][-1] = [text, reply]
        history['visible'][-1] = [visible_text, apply_extensions(reply, "output")]
        if not substring_found:
            yield history['visible']
        if next_character_found:
            break
    yield history['visible']

def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None):
    question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True)
    eos_token = '\n' if check else None
    for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
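
The key trick in this hunk is the split between history['internal'] (what the model actually sees: the caption text) and history['visible'] (what the user sees: the embedded image). A rough illustration of the resulting state after sending one picture (names and reply invented, values abridged):

# Illustrative history state after one picture message (abridged).
history = {
    'internal': [['*You sends Bot a picture that contains the following: "..."*', 'Nice photo!']],
    'visible':  [['<img src="data:image/jpeg;base64,...">', 'Nice photo!']],
}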
@@ -589,19 +598,19 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to
            break
        yield apply_extensions(reply, "output")

def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None):
    for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
        yield generate_chat_html(_history, name1, name2, character)

def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None):
    last = history['visible'].pop()
    history['internal'].pop()
    text = last[0]
    if args.cai_chat:
        for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
            yield i
    else:
        for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
            yield i

def remove_last_message(name1, name2):
@@ -791,6 +800,14 @@ def upload_your_profile_picture(img):
    img.save(Path(f'img_me.png'))
    print(f'Profile picture saved to "img_me.png"')
def generate_chat_picture(picture, name1, name2):
    # The model is shown a textual caption; the user is shown the image itself.
    text = f'*{name1} sends {name2} a picture that contains the following: "{bot_picture.caption_image(picture)}"*'
    buffer = BytesIO()
    picture.save(buffer, format="JPEG")
    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
    return text, visible_text
# Global variables
available_models = get_available_models()
available_presets = get_available_presets()
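
As a rough illustration of what generate_chat_picture returns (the speaker names and the blank test image are made up for the example):

# Illustration only: a blank image stands in for a real upload.
from PIL import Image

picture = Image.new('RGB', (64, 64), 'white')
text, visible_text = generate_chat_picture(picture, 'You', 'Bot')
# text         -> '*You sends Bot a picture that contains the following: "..."*'
# visible_text -> '<img src="data:image/jpeg;base64,...">'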
@@ -861,6 +878,9 @@ if args.chat or args.cai_chat:
        with gr.Row():
            buttons["Send last reply to input"] = gr.Button("Send last reply to input")
            buttons["Replace last reply"] = gr.Button("Replace last reply")
        if args.picture:
            with gr.Row():
                picture_select = gr.Image(label="Send a picture", type='pil')

        with gr.Row():
            with gr.Column():
@@ -907,12 +927,18 @@ if args.chat or args.cai_chat:
        create_extensions_block()

        input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size_slider]
        if args.picture:
            input_params.append(picture_select)

        if args.cai_chat:
            gen_events.append(buttons["Generate"].click(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen"))
            gen_events.append(textbox.submit(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream))
            if args.picture:
                picture_select.upload(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream)
        else:
            gen_events.append(buttons["Generate"].click(chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen"))
            gen_events.append(textbox.submit(chatbot_wrapper, input_params, display, show_progress=args.no_stream))
            if args.picture:
                picture_select.upload(chatbot_wrapper, input_params, display, show_progress=args.no_stream)
        gen_events.append(buttons["Regenerate"].click(regenerate_wrapper, input_params, display, show_progress=args.no_stream))
        gen_events.append(buttons["Impersonate"].click(impersonate_wrapper, input_params, textbox, show_progress=args.no_stream))
@@ -925,11 +951,14 @@ if args.chat or args.cai_chat:
        buttons["Upload character"].click(upload_character, [upload_char, upload_img], [character_menu])
        for i in ["Generate", "Regenerate", "Replace last reply"]:
            buttons[i].click(lambda x: "", textbox, textbox, show_progress=False)
        textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
        character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display])
        upload_img_tavern.upload(upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu])
        upload.upload(load_history, [upload, name1, name2], [])
        upload_img_me.upload(upload_your_profile_picture, [upload_img_me], [])
        if args.picture:
            picture_select.upload(lambda: None, [], [picture_select], show_progress=False)
        if args.cai_chat:
            upload.upload(redraw_html, [name1, name2], [display])