From 04b98a8485c93f3a6356947d7f500ece892e5931 Mon Sep 17 00:00:00 2001 From: Wojtab Date: Mon, 24 Apr 2023 03:58:15 +0200 Subject: [PATCH] Fix Continue for LLaVA (#1507) --- extensions/llava/script.py | 40 +++++++++++++------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/extensions/llava/script.py b/extensions/llava/script.py index a2ad34d5..d48e35fa 100644 --- a/extensions/llava/script.py +++ b/extensions/llava/script.py @@ -91,7 +91,7 @@ class LLaVAEmbedder: # replace the image token with the image patch token in the prompt (each occurrence) replace_token = LLaVAEmbedder.IM_PATCH.token * 256 replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token - prompt = re.sub(r"", replace_token, prompt, 1) + prompt = re.sub(r'', replace_token, prompt, 1) return prompt def _extract_image_features(self, images): @@ -146,11 +146,11 @@ class LLaVAEmbedder: @staticmethod def len_in_tokens(text): - images = re.findall(r"", text) + images = re.findall(r'', text) image_tokens = 0 for _ in images: image_tokens += 258 - return len(encode(re.sub(r"", '', text))[0]) + image_tokens + return len(encode(re.sub(r'', '', text))[0]) + image_tokens def add_chat_picture(picture, text, visible_text): @@ -166,32 +166,21 @@ def add_chat_picture(picture, text, visible_text): buffer = BytesIO() picture.save(buffer, format="JPEG") img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') - visible = f'' - internal = f'' + image = f'' + + + if '' in text: + text = text.replace('', image) + else: + text = text + '\n' + image if visible_text == '' or visible_text is None: visible_text = text - - if '' in text: - text = text.replace('', internal) + elif '' in visible_text: + visible_text = visible_text.replace('', image) else: - text = text + '\n' + internal + visible_text = visible_text + '\n' + image - if '' in visible_text: - visible_text = visible_text.replace('', visible) - else: - visible_text = visible_text + '\n' + visible - - return text, visible_text - - -def fix_picture_after_remove_last(text, visible_text): - image = re.search(r'', text) - if image is None: - return text, visible_text - if visible_text is None: - visible_text = text - text = re.sub(r'', "", text) return text, visible_text @@ -248,7 +237,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): def tokenizer_modifier(state, prompt, input_ids, input_embeds): global params start_ts = time.time() - image_matches = re.finditer(r"", prompt) + image_matches = re.finditer(r'', prompt) images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches] if len(images) == 0: @@ -276,4 +265,3 @@ def ui(): single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None) shared.gradio['Generate'].click(lambda: None, None, picture_select) shared.gradio['textbox'].submit(lambda: None, None, picture_select) - shared.gradio['Remove last'].click(lambda: input_hijack.update({"state": True, "value": fix_picture_after_remove_last}), None, None)