mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Fix Continue for LLaVA (#1507)
This commit is contained in:
parent
12212cf6be
commit
04b98a8485
@ -91,7 +91,7 @@ class LLaVAEmbedder:
|
|||||||
# replace the image token with the image patch token in the prompt (each occurrence)
|
# replace the image token with the image patch token in the prompt (each occurrence)
|
||||||
replace_token = LLaVAEmbedder.IM_PATCH.token * 256
|
replace_token = LLaVAEmbedder.IM_PATCH.token * 256
|
||||||
replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token
|
replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token
|
||||||
prompt = re.sub(r"<image:([A-Za-z0-9+/=]+)>", replace_token, prompt, 1)
|
prompt = re.sub(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', replace_token, prompt, 1)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def _extract_image_features(self, images):
|
def _extract_image_features(self, images):
|
||||||
@ -146,11 +146,11 @@ class LLaVAEmbedder:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def len_in_tokens(text):
|
def len_in_tokens(text):
|
||||||
images = re.findall(r"<image:[A-Za-z0-9+/=]+>", text)
|
images = re.findall(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', text)
|
||||||
image_tokens = 0
|
image_tokens = 0
|
||||||
for _ in images:
|
for _ in images:
|
||||||
image_tokens += 258
|
image_tokens += 258
|
||||||
return len(encode(re.sub(r"<image:[A-Za-z0-9+/=]+>", '', text))[0]) + image_tokens
|
return len(encode(re.sub(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', '', text))[0]) + image_tokens
|
||||||
|
|
||||||
|
|
||||||
def add_chat_picture(picture, text, visible_text):
|
def add_chat_picture(picture, text, visible_text):
|
||||||
@ -166,32 +166,21 @@ def add_chat_picture(picture, text, visible_text):
|
|||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
picture.save(buffer, format="JPEG")
|
picture.save(buffer, format="JPEG")
|
||||||
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||||
visible = f'<img src="data:image/jpeg;base64,{img_str}">'
|
image = f'<img src="data:image/jpeg;base64,{img_str}">'
|
||||||
internal = f'<image:{img_str}>'
|
|
||||||
|
|
||||||
|
if '<image>' in text:
|
||||||
|
text = text.replace('<image>', image)
|
||||||
|
else:
|
||||||
|
text = text + '\n' + image
|
||||||
|
|
||||||
if visible_text == '' or visible_text is None:
|
if visible_text == '' or visible_text is None:
|
||||||
visible_text = text
|
visible_text = text
|
||||||
|
elif '<image>' in visible_text:
|
||||||
if '<image>' in text:
|
visible_text = visible_text.replace('<image>', image)
|
||||||
text = text.replace('<image>', internal)
|
|
||||||
else:
|
else:
|
||||||
text = text + '\n' + internal
|
visible_text = visible_text + '\n' + image
|
||||||
|
|
||||||
if '<image>' in visible_text:
|
|
||||||
visible_text = visible_text.replace('<image>', visible)
|
|
||||||
else:
|
|
||||||
visible_text = visible_text + '\n' + visible
|
|
||||||
|
|
||||||
return text, visible_text
|
|
||||||
|
|
||||||
|
|
||||||
def fix_picture_after_remove_last(text, visible_text):
|
|
||||||
image = re.search(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', text)
|
|
||||||
if image is None:
|
|
||||||
return text, visible_text
|
|
||||||
if visible_text is None:
|
|
||||||
visible_text = text
|
|
||||||
text = re.sub(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', "<image:\\1>", text)
|
|
||||||
return text, visible_text
|
return text, visible_text
|
||||||
|
|
||||||
|
|
||||||
@ -248,7 +237,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
||||||
global params
|
global params
|
||||||
start_ts = time.time()
|
start_ts = time.time()
|
||||||
image_matches = re.finditer(r"<image:([A-Za-z0-9+/=]+)>", prompt)
|
image_matches = re.finditer(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt)
|
||||||
images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches]
|
images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches]
|
||||||
|
|
||||||
if len(images) == 0:
|
if len(images) == 0:
|
||||||
@ -276,4 +265,3 @@ def ui():
|
|||||||
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
|
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
|
||||||
shared.gradio['Generate'].click(lambda: None, None, picture_select)
|
shared.gradio['Generate'].click(lambda: None, None, picture_select)
|
||||||
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
|
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
|
||||||
shared.gradio['Remove last'].click(lambda: input_hijack.update({"state": True, "value": fix_picture_after_remove_last}), None, None)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user