From 8fa5f651d633b93bc1f188c355ae3d5250e8d538 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 9 May 2023 20:20:35 -0300
Subject: [PATCH] Style changes

---
 extensions/multimodal/multimodal_embedder.py | 8 ++++----
 extensions/multimodal/pipeline_loader.py     | 2 +-
 extensions/multimodal/script.py              | 9 ++++-----
 3 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/extensions/multimodal/multimodal_embedder.py b/extensions/multimodal/multimodal_embedder.py
index 73f07dc2..816e3866 100644
--- a/extensions/multimodal/multimodal_embedder.py
+++ b/extensions/multimodal/multimodal_embedder.py
@@ -46,12 +46,12 @@ class MultimodalEmbedder:
                 break
             # found an image, append image start token to the text
             if match.start() > 0:
-                parts.append(PromptPart(text=prompt[curr:curr+match.start()]+self.pipeline.image_start()))
+                parts.append(PromptPart(text=prompt[curr:curr + match.start()] + self.pipeline.image_start()))
             else:
                 parts.append(PromptPart(text=self.pipeline.image_start()))
             # append the image
             parts.append(PromptPart(
-                text=match.group(0), 
+                text=match.group(0),
                 image=Image.open(BytesIO(base64.b64decode(match.group(1)))) if load_images else None,
                 is_image=True
             ))
@@ -94,14 +94,14 @@ class MultimodalEmbedder:
 
     def _encode_text(self, state, parts: List[PromptPart]) -> List[PromptPart]:
         """Encode text to token_ids, also truncate the prompt, if necessary.
-        
+
         The chat/instruct mode should make prompts that fit in get_max_prompt_length, but if max_new_tokens
         are set such that the context + min_rows don't fit, we can get a prompt which is too long.
         We can't truncate image embeddings, as it leads to broken generation, so remove the images instead and warn the user
         """
         encoded: List[PromptPart] = []
         for i, part in enumerate(parts):
-            encoded.append(self._encode_single_text(part, i==0 and state['add_bos_token']))
+            encoded.append(self._encode_single_text(part, i == 0 and state['add_bos_token']))
 
         # truncation:
         max_len = get_max_prompt_length(state)
diff --git a/extensions/multimodal/pipeline_loader.py b/extensions/multimodal/pipeline_loader.py
index 24a97022..3ebdb104 100644
--- a/extensions/multimodal/pipeline_loader.py
+++ b/extensions/multimodal/pipeline_loader.py
@@ -26,7 +26,7 @@ def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
 
     if shared.args.multimodal_pipeline is not None:
         for k in pipeline_modules:
-            if hasattr(pipeline_modules[k], 'get_pipeline'): 
+            if hasattr(pipeline_modules[k], 'get_pipeline'):
                 pipeline = getattr(pipeline_modules[k], 'get_pipeline')(shared.args.multimodal_pipeline, params)
                 if pipeline is not None:
                     return (pipeline, k)
diff --git a/extensions/multimodal/script.py b/extensions/multimodal/script.py
index 02cf56ba..aeaadffd 100644
--- a/extensions/multimodal/script.py
+++ b/extensions/multimodal/script.py
@@ -42,14 +42,13 @@ def add_chat_picture(picture, text, visible_text):
         longest_edge = int(shortest_edge * aspect_ratio)
         w = shortest_edge if picture.width < picture.height else longest_edge
         h = shortest_edge if picture.width >= picture.height else longest_edge
-    picture = picture.resize((w,h))
+    picture = picture.resize((w, h))
 
     buffer = BytesIO()
     picture.save(buffer, format="JPEG")
     img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
     image = f'<img src="data:image/jpeg;base64,{img_str}">'
-
     if '<image>' in text:
         text = text.replace('<image>', image)
     else:
         text = text + '\n' + image
@@ -80,8 +79,8 @@ def tokenizer_modifier(state, prompt, input_ids, input_embeds):
     prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
     logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
     return (prompt,
-        input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
-        input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
+            input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
+            input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
 
 
 def ui():
@@ -97,7 +96,7 @@ def ui():
             [picture_select],
             None
         )
-        picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
+        picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["", ""]}), None, None)
         single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
         shared.gradio['Generate'].click(lambda: None, None, picture_select)
         shared.gradio['textbox'].submit(lambda: None, None, picture_select)