From 5389fce8e11a6018c44cfea4b29a1a0216ab5687 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=CE=A6=CF=86?= <42910943+Brawlence@users.noreply.github.com>
Date: Wed, 22 Mar 2023 07:47:54 +0300
Subject: [PATCH 1/4] Extensions performance & memory optimisations
Reworked remove_surrounded_chars() in elevenlabs_tts, silero_tts and sd_api_pictures to use a regular expression ( https://regexr.com/7alb5 ) instead of repeated string concatenation. This should be both faster and more robust in handling asterisks.
Reduced the memory footprint of send_pictures and sd_api_pictures by scaling the images shown in the chat down to at most 300 pixels on their longest side. (The user already has the original of any picture they send, and there is an option to save the SD generations.)
This should fix the history growing annoyingly large when multiple pictures are present.
---
extensions/elevenlabs_tts/script.py | 14 ++++++--------
extensions/sd_api_pictures/script.py | 26 ++++++++++++++++----------
extensions/send_pictures/script.py | 13 ++++++++++++-
extensions/silero_tts/script.py | 13 +++++--------
4 files changed, 39 insertions(+), 27 deletions(-)
diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py
index 7339cc73..cee64c06 100644
--- a/extensions/elevenlabs_tts/script.py
+++ b/extensions/elevenlabs_tts/script.py
@@ -4,6 +4,8 @@ import gradio as gr
from elevenlabslib import ElevenLabsUser
from elevenlabslib.helpers import save_bytes_to_path
+import re
+
import modules.shared as shared
params = {
@@ -52,14 +54,10 @@ def refresh_voices():
return
def remove_surrounded_chars(string):
- new_string = ""
- in_star = False
- for char in string:
- if char == '*':
- in_star = not in_star
- elif not in_star:
- new_string += char
- return new_string
+ # regexp is way faster than repeated string concatenation!
+ # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
+ # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
+ return re.sub('\*[^\*]*?(\*|$)','',string)
def input_modifier(string):
"""
diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py
index b9fba2b9..03d3c784 100644
--- a/extensions/sd_api_pictures/script.py
+++ b/extensions/sd_api_pictures/script.py
@@ -1,5 +1,6 @@
import base64
import io
+import re
from pathlib import Path
import gradio as gr
@@ -31,14 +32,10 @@ picture_response = False # specifies if the next model response should appear as
pic_id = 0
def remove_surrounded_chars(string):
- new_string = ""
- in_star = False
- for char in string:
- if char == '*':
- in_star = not in_star
- elif not in_star:
- new_string += char
- return new_string
+ # regexp is way faster than repeated string concatenation!
+ # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
+ # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
+ return re.sub('\*[^\*]*?(\*|$)','',string)
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
def input_modifier(string):
@@ -54,6 +51,8 @@ def input_modifier(string):
mediums = ['image', 'pic', 'picture', 'photo']
subjects = ['yourself', 'own']
lowstr = string.lower()
+
+ # TODO: refactor out to separate handler and also replace detection with a regexp
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
picture_response = True
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
@@ -91,8 +90,15 @@ def get_SD_pictures(description):
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
image.save(output_file.as_posix())
pic_id += 1
- # lower the resolution of received images for the chat, otherwise the history size gets out of control quickly with all the base64 values
- newsize = (300, 300)
+ # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
+ width, height = image.size
+ if (width > 300):
+ height = int(height * (300 / width))
+ width = 300
+ elif (height > 300):
+ width = int(width * (300 / height))
+ height = 300
+ newsize = (width, height)
image = image.resize(newsize, Image.LANCZOS)
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py
index b0c35632..05ceed8b 100644
--- a/extensions/send_pictures/script.py
+++ b/extensions/send_pictures/script.py
@@ -4,6 +4,7 @@ from io import BytesIO
import gradio as gr
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor
+from PIL import Image
import modules.chat as chat
import modules.shared as shared
@@ -25,10 +26,20 @@ def caption_image(raw_image):
def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
+ # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
+ width, height = picture.size
+ if (width > 300):
+ height = int(height * (300 / width))
+ width = 300
+ elif (height > 300):
+ width = int(width * (300 / height))
+ height = 300
+ newsize = (width, height)
+ picture = picture.resize(newsize, Image.LANCZOS)
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
- visible_text = f''
+ visible_text = f''
return text, visible_text
def ui():
diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py
index f611dc27..447c7033 100644
--- a/extensions/silero_tts/script.py
+++ b/extensions/silero_tts/script.py
@@ -3,6 +3,7 @@ from pathlib import Path
import gradio as gr
import torch
+import re
import modules.chat as chat
import modules.shared as shared
@@ -46,14 +47,10 @@ def load_model():
model = load_model()
def remove_surrounded_chars(string):
- new_string = ""
- in_star = False
- for char in string:
- if char == '*':
- in_star = not in_star
- elif not in_star:
- new_string += char
- return new_string
+ # regexp is way faster than repeated string concatenation!
+ # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
+ # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
+ return re.sub('\*[^\*]*?(\*|$)','',string)
def remove_tts_from_history(name1, name2):
for i, entry in enumerate(shared.history['internal']):
From 104212529ff8c6b41baf12bfd4673fe782ed04d7 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 22 Mar 2023 15:55:03 -0300
Subject: [PATCH 2/4] Minor changes
---
extensions/elevenlabs_tts/script.py | 9 +++------
extensions/sd_api_pictures/script.py | 8 +++-----
extensions/send_pictures/script.py | 9 ++++-----
extensions/silero_tts/script.py | 8 +++-----
4 files changed, 13 insertions(+), 21 deletions(-)
diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py
index cee64c06..2e8b184f 100644
--- a/extensions/elevenlabs_tts/script.py
+++ b/extensions/elevenlabs_tts/script.py
@@ -1,13 +1,11 @@
+import re
from pathlib import Path
import gradio as gr
+import modules.shared as shared
from elevenlabslib import ElevenLabsUser
from elevenlabslib.helpers import save_bytes_to_path
-import re
-
-import modules.shared as shared
-
params = {
'activate': True,
'api_key': '12345',
@@ -54,7 +52,6 @@ def refresh_voices():
return
def remove_surrounded_chars(string):
- # regexp is way faster than repeated string concatenation!
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
@@ -113,4 +110,4 @@ def ui():
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
connect.click(check_valid_api, [], connection_status)
- connect.click(refresh_voices, [], voice)
+ connect.click(refresh_voices, [], voice)
\ No newline at end of file
diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py
index 03d3c784..1f6ba2d2 100644
--- a/extensions/sd_api_pictures/script.py
+++ b/extensions/sd_api_pictures/script.py
@@ -4,13 +4,12 @@ import re
from pathlib import Path
import gradio as gr
+import modules.chat as chat
+import modules.shared as shared
import requests
import torch
from PIL import Image
-import modules.chat as chat
-import modules.shared as shared
-
torch._C._jit_set_profiling_mode(False)
# parameters which can be customized in settings.json of webui
@@ -32,7 +31,6 @@ picture_response = False # specifies if the next model response should appear as
pic_id = 0
def remove_surrounded_chars(string):
- # regexp is way faster than repeated string concatenation!
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
@@ -186,4 +184,4 @@ def ui():
force_btn.click(force_pic)
generate_now_btn.click(force_pic)
- generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
+ generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
\ No newline at end of file
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py
index 05ceed8b..46393e6c 100644
--- a/extensions/send_pictures/script.py
+++ b/extensions/send_pictures/script.py
@@ -2,12 +2,11 @@ import base64
from io import BytesIO
import gradio as gr
-import torch
-from transformers import BlipForConditionalGeneration, BlipProcessor
-from PIL import Image
-
import modules.chat as chat
import modules.shared as shared
+import torch
+from PIL import Image
+from transformers import BlipForConditionalGeneration, BlipProcessor
# If 'state' is True, will hijack the next chat generation with
# custom input text given by 'value' in the format [text, visible_text]
@@ -54,4 +53,4 @@ def ui():
picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear the picture from the upload field
- picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
+ picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
\ No newline at end of file
diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py
index 447c7033..a81a5da1 100644
--- a/extensions/silero_tts/script.py
+++ b/extensions/silero_tts/script.py
@@ -1,12 +1,11 @@
+import re
import time
from pathlib import Path
import gradio as gr
-import torch
-import re
-
import modules.chat as chat
import modules.shared as shared
+import torch
torch._C._jit_set_profiling_mode(False)
@@ -47,7 +46,6 @@ def load_model():
model = load_model()
def remove_surrounded_chars(string):
- # regexp is way faster than repeated string concatenation!
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
@@ -163,4 +161,4 @@ def ui():
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
voice.change(lambda x: params.update({"speaker": x}), voice, None)
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
- v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
+ v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
\ No newline at end of file
From 0abff499e2e40fb308b406b82ad3f850683720e3 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 22 Mar 2023 16:03:05 -0300
Subject: [PATCH 3/4] Use image.thumbnail
---
extensions/sd_api_pictures/script.py | 10 +---------
extensions/send_pictures/script.py | 10 +---------
2 files changed, 2 insertions(+), 18 deletions(-)
diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py
index 1f6ba2d2..cc85f3b3 100644
--- a/extensions/sd_api_pictures/script.py
+++ b/extensions/sd_api_pictures/script.py
@@ -89,15 +89,7 @@ def get_SD_pictures(description):
image.save(output_file.as_posix())
pic_id += 1
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
- width, height = image.size
- if (width > 300):
- height = int(height * (300 / width))
- width = 300
- elif (height > 300):
- width = int(width * (300 / height))
- height = 300
- newsize = (width, height)
- image = image.resize(newsize, Image.LANCZOS)
+ image.thumbnail((300, 300))
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
buffered.seek(0)
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py
index 46393e6c..196c7d53 100644
--- a/extensions/send_pictures/script.py
+++ b/extensions/send_pictures/script.py
@@ -26,15 +26,7 @@ def caption_image(raw_image):
def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
- width, height = picture.size
- if (width > 300):
- height = int(height * (300 / width))
- width = 300
- elif (height > 300):
- width = int(width * (300 / height))
- height = 300
- newsize = (width, height)
- picture = picture.resize(newsize, Image.LANCZOS)
+ image.thumbnail((300, 300))
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
From bfb1be2820417fb5d2e843d9e164862b8962446d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 22 Mar 2023 16:09:48 -0300
Subject: [PATCH 4/4] Minor fix
---
extensions/send_pictures/script.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py
index 196c7d53..556a88e5 100644
--- a/extensions/send_pictures/script.py
+++ b/extensions/send_pictures/script.py
@@ -26,7 +26,7 @@ def caption_image(raw_image):
def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
- image.thumbnail((300, 300))
+ picture.thumbnail((300, 300))
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
@@ -45,4 +45,4 @@ def ui():
picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear the picture from the upload field
- picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
\ No newline at end of file
+ picture_select.upload(lambda : None, [], [picture_select], show_progress=False)