Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-11-22 08:07:56 +01:00
Style changes
parent: e9e75a9ec7
commit: 8fa5f651d6
@@ -46,12 +46,12 @@ class MultimodalEmbedder:
                 break
             # found an image, append image start token to the text
             if match.start() > 0:
-                parts.append(PromptPart(text=prompt[curr:curr+match.start()]+self.pipeline.image_start()))
+                parts.append(PromptPart(text=prompt[curr:curr + match.start()] + self.pipeline.image_start()))
             else:
                 parts.append(PromptPart(text=self.pipeline.image_start()))
             # append the image
             parts.append(PromptPart(
                 text=match.group(0),
                 image=Image.open(BytesIO(base64.b64decode(match.group(1)))) if load_images else None,
                 is_image=True
             ))
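For reference, here is a minimal standalone sketch of the splitting step this hunk touches: the prompt is scanned for inline data-URI image tags, text runs become text parts, and each image becomes its own part so it can later be embedded separately. The regex, the simplified PromptPart, and the placeholder image_start token are illustrative assumptions, not the extension's exact definitions (the real method also appends an image_end token, omitted here).

# Illustrative sketch only: PromptPart is a simplified stand-in and the
# regex/image_start token are assumptions, not the extension's exact code.
import base64
import re
from dataclasses import dataclass
from io import BytesIO
from typing import List, Optional

from PIL import Image

IMG_RE = re.compile(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">')


@dataclass
class PromptPart:
    text: str = ''
    image: Optional[Image.Image] = None
    is_image: bool = False


def split_prompt(prompt: str, image_start: str = '<Img>', load_images: bool = False) -> List[PromptPart]:
    parts: List[PromptPart] = []
    curr = 0
    while True:
        match = IMG_RE.search(prompt[curr:])
        if match is None:
            # no more images, the rest of the prompt is plain text
            parts.append(PromptPart(text=prompt[curr:]))
            break
        # found an image, append image start token to the text (as in the hunk)
        if match.start() > 0:
            parts.append(PromptPart(text=prompt[curr:curr + match.start()] + image_start))
        else:
            parts.append(PromptPart(text=image_start))
        # append the image itself, decoding the base64 payload only if asked
        parts.append(PromptPart(
            text=match.group(0),
            image=Image.open(BytesIO(base64.b64decode(match.group(1)))) if load_images else None,
            is_image=True
        ))
        curr += match.end()
    return parts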
@@ -94,14 +94,14 @@ class MultimodalEmbedder:
 
     def _encode_text(self, state, parts: List[PromptPart]) -> List[PromptPart]:
         """Encode text to token_ids, also truncate the prompt, if necessary.
 
         The chat/instruct mode should make prompts that fit in get_max_prompt_length, but if max_new_tokens are set
         such that the context + min_rows don't fit, we can get a prompt which is too long.
         We can't truncate image embeddings, as it leads to broken generation, so remove the images instead and warn the user
         """
         encoded: List[PromptPart] = []
         for i, part in enumerate(parts):
-            encoded.append(self._encode_single_text(part, i==0 and state['add_bos_token']))
+            encoded.append(self._encode_single_text(part, i == 0 and state['add_bos_token']))
 
         # truncation:
         max_len = get_max_prompt_length(state)
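The docstring above encodes a policy worth spelling out: text parts can be shortened token by token, but an image's embedding occupies a fixed block of positions and slicing it corrupts generation, so an oversize prompt loses whole images instead. A rough sketch of that policy, using a simplified part type (the Part fields and warning text are assumptions):

# Rough sketch of the drop-images-instead-of-slicing policy described in the
# docstring; Part and its fields are simplified stand-ins, not the real types.
import logging
from dataclasses import dataclass
from typing import List


@dataclass
class Part:
    length: int          # context positions this part occupies
    is_image: bool = False


def truncate(parts: List[Part], max_len: int) -> List[Part]:
    out = list(parts)
    total = sum(p.length for p in out)
    # an image embedding is atomic: remove the oldest images whole and warn
    while total > max_len and any(p.is_image for p in out):
        idx = next(i for i, p in enumerate(out) if p.is_image)
        total -= out.pop(idx).length
        logging.warning('Removed an image from the prompt to fit the context')
    # any remaining overflow can safely come out of the leading text part
    if total > max_len and out:
        out[0].length = max(0, out[0].length - (total - max_len))
    return out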
@@ -26,7 +26,7 @@ def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
 
     if shared.args.multimodal_pipeline is not None:
         for k in pipeline_modules:
             if hasattr(pipeline_modules[k], 'get_pipeline'):
                 pipeline = getattr(pipeline_modules[k], 'get_pipeline')(shared.args.multimodal_pipeline, params)
                 if pipeline is not None:
                     return (pipeline, k)
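The loop in this hunk is a simple plugin-discovery pattern: every candidate module may optionally expose get_pipeline(name, params), and the first module that returns a non-None pipeline wins. A self-contained sketch of the same first-match-wins lookup (the two fake modules and pipeline names are hypothetical):

# First-match-wins pipeline lookup, mirroring load_pipeline above; the fake
# modules and names are hypothetical.
from types import SimpleNamespace

llava = SimpleNamespace(get_pipeline=lambda name, params: 'llava-pipeline' if name == 'llava-13b' else None)
minigpt = SimpleNamespace()  # exposes no get_pipeline, so it is skipped


def find_pipeline(pipeline_modules: dict, name: str, params: dict):
    for k in pipeline_modules:
        if hasattr(pipeline_modules[k], 'get_pipeline'):
            pipeline = getattr(pipeline_modules[k], 'get_pipeline')(name, params)
            if pipeline is not None:
                return (pipeline, k)
    return (None, None)


print(find_pipeline({'minigpt': minigpt, 'llava': llava}, 'llava-13b', {}))
# -> ('llava-pipeline', 'llava')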
@@ -42,14 +42,13 @@ def add_chat_picture(picture, text, visible_text):
     longest_edge = int(shortest_edge * aspect_ratio)
     w = shortest_edge if picture.width < picture.height else longest_edge
     h = shortest_edge if picture.width >= picture.height else longest_edge
-    picture = picture.resize((w,h))
+    picture = picture.resize((w, h))
 
     buffer = BytesIO()
     picture.save(buffer, format="JPEG")
     img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
     image = f'<img src="data:image/jpeg;base64,{img_str}">'
 
-
     if '<image>' in text:
         text = text.replace('<image>', image)
     else:
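For context, add_chat_picture scales the picture so its shortest edge hits a target size while preserving aspect ratio, re-encodes it as JPEG, and inlines it into the chat text as a base64 data URI, which is exactly the tag the embedder's regex later picks back out. A standalone sketch, where the 224 target edge and the aspect_ratio computation are assumptions filled in for illustration:

# Standalone sketch of the resize-and-inline step; the 224 target edge and
# the aspect_ratio computation are assumptions, not values from the diff.
import base64
from io import BytesIO

from PIL import Image


def to_inline_image(picture: Image.Image, shortest_edge: int = 224) -> str:
    aspect_ratio = max(picture.width, picture.height) / min(picture.width, picture.height)
    longest_edge = int(shortest_edge * aspect_ratio)
    w = shortest_edge if picture.width < picture.height else longest_edge
    h = shortest_edge if picture.width >= picture.height else longest_edge
    picture = picture.resize((w, h)).convert('RGB')  # JPEG cannot store alpha

    buffer = BytesIO()
    picture.save(buffer, format="JPEG")
    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f'<img src="data:image/jpeg;base64,{img_str}">'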
@@ -80,8 +79,8 @@ def tokenizer_modifier(state, prompt, input_ids, input_embeds):
     prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
     logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
     return (prompt,
             input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
             input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
 
 
 def ui():
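The return statement here is where the batch dimension appears: the embedder produces per-position ids and embeddings, and unsqueeze(0) turns them into batch-of-one tensors with the dtypes the model expects. A small shape check (seq_len and hidden are illustrative numbers, not values from the extension):

# Shape sketch for the batched return above; dimensions are illustrative.
import torch

seq_len, hidden = 8, 4096
input_ids = torch.zeros(seq_len, dtype=torch.int64)
input_embeds = torch.zeros(seq_len, hidden)

batched_ids = input_ids.unsqueeze(0)        # (1, seq_len)
batched_embeds = input_embeds.unsqueeze(0)  # (1, seq_len, hidden)
assert batched_ids.shape == (1, seq_len)
assert batched_embeds.shape == (1, seq_len, hidden)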
@@ -97,7 +96,7 @@ def ui():
         [picture_select],
         None
     )
-    picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
+    picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["", ""]}), None, None)
     single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
     shared.gradio['Generate'].click(lambda: None, None, picture_select)
     shared.gradio['textbox'].submit(lambda: None, None, picture_select)
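The last two lines of this hunk use a common Gradio reset idiom: wiring a callback that returns None to an output component clears that component, so the selected picture is dropped as soon as generation starts or the textbox is submitted. A minimal sketch of the same wiring (the components are hypothetical stand-ins for the extension's UI):

# Minimal sketch of the reset-on-generate wiring; components are hypothetical.
import gradio as gr

with gr.Blocks() as demo:
    picture_select = gr.Image(type='pil', label='Send a picture')
    generate = gr.Button('Generate')
    # returning None from the callback clears the image component
    generate.click(lambda: None, None, picture_select)

demo.launch()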