mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Extensions performance & memory optimisations
Reworked remove_surrounded_chars() to use regular expression ( https://regexr.com/7alb5 ) instead of repeated string concatenations for elevenlab_tts, silero_tts, sd_api_pictures. This should be both faster and more robust in handling asterisks. Reduced the memory footprint of send_pictures and sd_api_pictures by scaling the images in the chat to 300 pixels max-side wise. (The user already has the original in case of the sent picture and there's an option to save the SD generation). This should fix history growing annoyingly large with multiple pictures present
This commit is contained in:
parent
45b7e53565
commit
5389fce8e1
@ -4,6 +4,8 @@ import gradio as gr
|
||||
from elevenlabslib import ElevenLabsUser
|
||||
from elevenlabslib.helpers import save_bytes_to_path
|
||||
|
||||
import re
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
params = {
|
||||
@ -52,14 +54,10 @@ def refresh_voices():
|
||||
return
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
new_string = ""
|
||||
in_star = False
|
||||
for char in string:
|
||||
if char == '*':
|
||||
in_star = not in_star
|
||||
elif not in_star:
|
||||
new_string += char
|
||||
return new_string
|
||||
# regexp is way faster than repeated string concatenation!
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
|
@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
@ -31,14 +32,10 @@ picture_response = False # specifies if the next model response should appear as
|
||||
pic_id = 0
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
new_string = ""
|
||||
in_star = False
|
||||
for char in string:
|
||||
if char == '*':
|
||||
in_star = not in_star
|
||||
elif not in_star:
|
||||
new_string += char
|
||||
return new_string
|
||||
# regexp is way faster than repeated string concatenation!
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
|
||||
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
|
||||
def input_modifier(string):
|
||||
@ -54,6 +51,8 @@ def input_modifier(string):
|
||||
mediums = ['image', 'pic', 'picture', 'photo']
|
||||
subjects = ['yourself', 'own']
|
||||
lowstr = string.lower()
|
||||
|
||||
# TODO: refactor out to separate handler and also replace detection with a regexp
|
||||
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
|
||||
picture_response = True
|
||||
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
||||
@ -91,8 +90,15 @@ def get_SD_pictures(description):
|
||||
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
|
||||
image.save(output_file.as_posix())
|
||||
pic_id += 1
|
||||
# lower the resolution of received images for the chat, otherwise the history size gets out of control quickly with all the base64 values
|
||||
newsize = (300, 300)
|
||||
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||
width, height = image.size
|
||||
if (width > 300):
|
||||
height = int(height * (300 / width))
|
||||
width = 300
|
||||
elif (height > 300):
|
||||
width = int(width * (300 / height))
|
||||
height = 300
|
||||
newsize = (width, height)
|
||||
image = image.resize(newsize, Image.LANCZOS)
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
|
@ -4,6 +4,7 @@ from io import BytesIO
|
||||
import gradio as gr
|
||||
import torch
|
||||
from transformers import BlipForConditionalGeneration, BlipProcessor
|
||||
from PIL import Image
|
||||
|
||||
import modules.chat as chat
|
||||
import modules.shared as shared
|
||||
@ -25,10 +26,20 @@ def caption_image(raw_image):
|
||||
|
||||
def generate_chat_picture(picture, name1, name2):
|
||||
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
|
||||
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||
width, height = picture.size
|
||||
if (width > 300):
|
||||
height = int(height * (300 / width))
|
||||
width = 300
|
||||
elif (height > 300):
|
||||
width = int(width * (300 / height))
|
||||
height = 300
|
||||
newsize = (width, height)
|
||||
picture = picture.resize(newsize, Image.LANCZOS)
|
||||
buffer = BytesIO()
|
||||
picture.save(buffer, format="JPEG")
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
|
||||
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
|
||||
return text, visible_text
|
||||
|
||||
def ui():
|
||||
|
@ -3,6 +3,7 @@ from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
import re
|
||||
|
||||
import modules.chat as chat
|
||||
import modules.shared as shared
|
||||
@ -46,14 +47,10 @@ def load_model():
|
||||
model = load_model()
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
new_string = ""
|
||||
in_star = False
|
||||
for char in string:
|
||||
if char == '*':
|
||||
in_star = not in_star
|
||||
elif not in_star:
|
||||
new_string += char
|
||||
return new_string
|
||||
# regexp is way faster than repeated string concatenation!
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
|
||||
def remove_tts_from_history(name1, name2):
|
||||
for i, entry in enumerate(shared.history['internal']):
|
||||
|
Loading…
Reference in New Issue
Block a user