Fix Continue for LLaVA (#1507)

This commit is contained in:
Wojtab 2023-04-24 03:58:15 +02:00 committed by GitHub
parent 12212cf6be
commit 04b98a8485
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -91,7 +91,7 @@ class LLaVAEmbedder:
# replace the image token with the image patch token in the prompt (each occurrence) # replace the image token with the image patch token in the prompt (each occurrence)
replace_token = LLaVAEmbedder.IM_PATCH.token * 256 replace_token = LLaVAEmbedder.IM_PATCH.token * 256
replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token
prompt = re.sub(r"<image:([A-Za-z0-9+/=]+)>", replace_token, prompt, 1) prompt = re.sub(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', replace_token, prompt, 1)
return prompt return prompt
def _extract_image_features(self, images): def _extract_image_features(self, images):
@ -146,11 +146,11 @@ class LLaVAEmbedder:
@staticmethod @staticmethod
def len_in_tokens(text): def len_in_tokens(text):
images = re.findall(r"<image:[A-Za-z0-9+/=]+>", text) images = re.findall(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', text)
image_tokens = 0 image_tokens = 0
for _ in images: for _ in images:
image_tokens += 258 image_tokens += 258
return len(encode(re.sub(r"<image:[A-Za-z0-9+/=]+>", '', text))[0]) + image_tokens return len(encode(re.sub(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', '', text))[0]) + image_tokens
def add_chat_picture(picture, text, visible_text): def add_chat_picture(picture, text, visible_text):
@ -166,32 +166,21 @@ def add_chat_picture(picture, text, visible_text):
buffer = BytesIO() buffer = BytesIO()
picture.save(buffer, format="JPEG") picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
visible = f'<img src="data:image/jpeg;base64,{img_str}">' image = f'<img src="data:image/jpeg;base64,{img_str}">'
internal = f'<image:{img_str}>'
if '<image>' in text:
text = text.replace('<image>', image)
else:
text = text + '\n' + image
if visible_text == '' or visible_text is None: if visible_text == '' or visible_text is None:
visible_text = text visible_text = text
elif '<image>' in visible_text:
if '<image>' in text: visible_text = visible_text.replace('<image>', image)
text = text.replace('<image>', internal)
else: else:
text = text + '\n' + internal visible_text = visible_text + '\n' + image
if '<image>' in visible_text:
visible_text = visible_text.replace('<image>', visible)
else:
visible_text = visible_text + '\n' + visible
return text, visible_text
def fix_picture_after_remove_last(text, visible_text):
image = re.search(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', text)
if image is None:
return text, visible_text
if visible_text is None:
visible_text = text
text = re.sub(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', "<image:\\1>", text)
return text, visible_text return text, visible_text
@ -248,7 +237,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
def tokenizer_modifier(state, prompt, input_ids, input_embeds): def tokenizer_modifier(state, prompt, input_ids, input_embeds):
global params global params
start_ts = time.time() start_ts = time.time()
image_matches = re.finditer(r"<image:([A-Za-z0-9+/=]+)>", prompt) image_matches = re.finditer(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt)
images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches] images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches]
if len(images) == 0: if len(images) == 0:
@ -276,4 +265,3 @@ def ui():
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None) single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
shared.gradio['Generate'].click(lambda: None, None, picture_select) shared.gradio['Generate'].click(lambda: None, None, picture_select)
shared.gradio['textbox'].submit(lambda: None, None, picture_select) shared.gradio['textbox'].submit(lambda: None, None, picture_select)
shared.gradio['Remove last'].click(lambda: input_hijack.update({"state": True, "value": fix_picture_after_remove_last}), None, None)