Use BLIP to send a picture to the model

SillyLossy 2023-02-15 01:38:21 +02:00
parent 79d3a524f2
commit a7d98f494a
4 changed files with 58 additions and 18 deletions

modules/bot_picture.py Normal file

@@ -0,0 +1,9 @@
+from nataili_blip.model_manager import BlipModelManager
+from nataili_blip.caption import Caption
+
+def load_model():
+    model_name = "BLIP"
+    mm = BlipModelManager()
+    mm.download_model(model_name)
+    mm.load_blip(model_name)
+    return Caption(mm.loaded_models[model_name]["model"], mm.loaded_models[model_name]["device"])
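For reference, a minimal usage sketch of this loader. The only API details taken from the commit are `load_model()` itself and the fact that the returned Caption object is called directly on a PIL image (see `blip(picture)` in server.py below); the input file name is hypothetical:

```python
# Sketch only: assumes the nataili_blip behavior shown in this commit.
from PIL import Image

import modules.bot_picture as bot_picture

blip = bot_picture.load_model()    # downloads the BLIP weights on first run, then loads them
image = Image.open("example.jpg")  # hypothetical input image
caption = blip(image)              # the Caption object is callable on a PIL image
print(caption)                     # e.g. "a dog sitting on a couch"
```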

modules/html_generator.py

@@ -217,6 +217,12 @@ def generate_chat_html(history, name1, name2, character):
             .body {
             }
+            .body img {
+                max-width: 300px;
+                max-height: 300px;
+                border-radius: 20px;
+            }
     """
     output = ''

requirements.txt

@@ -4,4 +4,5 @@ bitsandbytes==0.37.0
 gradio==3.15.0
 numpy
 safetensors==0.2.8
+nataili_blip
 git+https://github.com/huggingface/transformers

server.py

@@ -23,6 +23,7 @@ from tqdm import tqdm
 from transformers import AutoConfig
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
+from io import BytesIO

 from modules.html_generator import *
 from modules.stopping_criteria import _SentinelTokenStoppingCriteria
@@ -53,6 +54,7 @@ parser.add_argument('--listen', action='store_true', help='Make the web UI reach
 parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
 parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
 parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
+parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes.')
 args = parser.parse_args()

 if (args.chat or args.cai_chat) and not args.no_stream:
@@ -97,6 +99,10 @@ if args.deepspeed:
     ds_config = generate_ds_config(args.bf16, 1 * world_size, args.nvme_offload_dir)
     dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration

+if args.picture and (args.cai_chat or args.chat):
+    import modules.bot_picture as bot_picture
+    blip = bot_picture.load_model()
+
 def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
@@ -561,8 +567,12 @@ def extract_message_from_reply(question, reply, current, other, check, extension
     return reply, next_character_found, substring_found

-def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
+def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
     original_text = text
+
+    if args.picture and picture is not None:
+        text, original_text = generate_chat_picture(picture, name1, name2)
+
     text = apply_extensions(text, "input")
     question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
     history['internal'].append(['', ''])
@@ -573,12 +583,12 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p,
         history['internal'][-1] = [text, reply]
         history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
         if not substring_found:
-            yield history['visible']
+            yield history['visible'], None
         if next_character_found:
             break

-    yield history['visible']
+    yield history['visible'], None

-def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
+def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
     question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True)
     eos_token = '\n' if check else None
     for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
@@ -589,20 +599,20 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to
             break
     yield apply_extensions(reply, "output")

-def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
-    for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
-        yield generate_chat_html(_history, name1, name2, character)
+def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
+    for _history, _ in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
+        yield generate_chat_html(_history, name1, name2, character), None

-def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
+def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
     last = history['visible'].pop()
     history['internal'].pop()
     text = last[0]

     if args.cai_chat:
-        for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
-            yield i
+        for i, _ in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
+            yield i, None
     else:
-        for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size):
-            yield i
+        for i, _ in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture):
+            yield i, None

 def remove_last_message(name1, name2):
     if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
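As context for the tuple yields above: when a Gradio event maps one callback to several output components, a generator callback yields one tuple per step, one element per component, and yielding None into the second slot is what resets the picture widget after it has been consumed. A toy sketch of that contract (names are illustrative, not from this commit):

```python
# Toy sketch: generator callback for a Gradio event with two outputs.
# Each yielded tuple maps positionally onto [chat_display, picture_widget];
# the trailing None clears the picture widget on every streamed step.
def stream_reply(text, picture):
    reply = f"echo: {text}"
    for i in range(1, len(reply) + 1):
        yield reply[:i], None
```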
@@ -791,6 +801,14 @@ def upload_your_profile_picture(img):
     img.save(Path(f'img_me.png'))
     print(f'Profile picture saved to "img_me.png"')

+def generate_chat_picture(picture, name1, name2):
+    text = f'*{name1} sends {name2} a picture that contains the following: "{blip(picture)}"*'
+    buffer = BytesIO()
+    picture.save(buffer, format="JPEG")
+    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
+    original_text = f'<img src="data:image/jpeg;base64,{img_str}">'
+    return text, original_text
+
 # Global variables
 available_models = get_available_models()
 available_presets = get_available_presets()
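To make the mechanics of generate_chat_picture concrete, here is a self-contained sketch of the same JPEG-to-data-URI round trip, with the BLIP caption call left out since it needs the loaded model (the solid-color test image is a stand-in for the upload):

```python
import base64
from io import BytesIO

from PIL import Image

# Serialize a PIL image to JPEG in memory, base64-encode the bytes, and
# embed them as a data URI so the chat HTML needs no separate image file.
picture = Image.new("RGB", (64, 64), "steelblue")  # stand-in for the uploaded picture
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
original_text = f'<img src="data:image/jpeg;base64,{img_str}">'
print(original_text[:60] + '...')
```

Note that generate_chat_picture returns two strings: `text` (the BLIP caption wrapped in a narration line) goes into the model prompt, while `original_text` (the inline image) is what the visible chat log displays.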
@@ -861,6 +879,9 @@ if args.chat or args.cai_chat:
         with gr.Row():
             buttons["Send last reply to input"] = gr.Button("Send last reply to input")
             buttons["Replace last reply"] = gr.Button("Replace last reply")
+        if args.picture:
+            with gr.Row():
+                picture_select = gr.Image(label="Send a picture", type='pil', display_label=True)

         with gr.Row():
             with gr.Column():
@@ -906,14 +927,17 @@ if args.chat or args.cai_chat:
     if args.extensions is not None:
         create_extensions_block()

-    input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size_slider]
+    input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size_slider, picture_select]
+    output_params = [display, picture_select]
     if args.cai_chat:
-        gen_events.append(buttons["Generate"].click(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen"))
-        gen_events.append(textbox.submit(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream))
+        gen_events.append(buttons["Generate"].click(cai_chatbot_wrapper, input_params, output_params, show_progress=args.no_stream, api_name="textgen"))
+        gen_events.append(textbox.submit(cai_chatbot_wrapper, input_params, output_params, show_progress=args.no_stream))
+        picture_select.upload(cai_chatbot_wrapper, input_params, output_params, show_progress=args.no_stream)
     else:
-        gen_events.append(buttons["Generate"].click(chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen"))
-        gen_events.append(textbox.submit(chatbot_wrapper, input_params, display, show_progress=args.no_stream))
-    gen_events.append(buttons["Regenerate"].click(regenerate_wrapper, input_params, display, show_progress=args.no_stream))
+        gen_events.append(buttons["Generate"].click(chatbot_wrapper, input_params, output_params, show_progress=args.no_stream, api_name="textgen"))
+        gen_events.append(textbox.submit(chatbot_wrapper, input_params, output_params, show_progress=args.no_stream))
+        picture_select.upload(chatbot_wrapper, input_params, output_params, show_progress=args.no_stream)
+    gen_events.append(buttons["Regenerate"].click(regenerate_wrapper, input_params, output_params, show_progress=args.no_stream))
     gen_events.append(buttons["Impersonate"].click(impersonate_wrapper, input_params, textbox, show_progress=args.no_stream))
     buttons["Send last reply to input"].click(send_last_reply_to_input, [], textbox, show_progress=args.no_stream)