mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Merge pull request #489 from Brawlence/ext-fixes
Extensions performance & memory optimisations
This commit is contained in:
commit
d5fc1bead7
@ -1,11 +1,11 @@
|
|||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import modules.shared as shared
|
||||||
from elevenlabslib import ElevenLabsUser
|
from elevenlabslib import ElevenLabsUser
|
||||||
from elevenlabslib.helpers import save_bytes_to_path
|
from elevenlabslib.helpers import save_bytes_to_path
|
||||||
|
|
||||||
import modules.shared as shared
|
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'activate': True,
|
'activate': True,
|
||||||
'api_key': '12345',
|
'api_key': '12345',
|
||||||
@ -52,14 +52,9 @@ def refresh_voices():
|
|||||||
return
|
return
|
||||||
|
|
||||||
def remove_surrounded_chars(string):
|
def remove_surrounded_chars(string):
|
||||||
new_string = ""
|
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||||
in_star = False
|
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||||
for char in string:
|
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||||
if char == '*':
|
|
||||||
in_star = not in_star
|
|
||||||
elif not in_star:
|
|
||||||
new_string += char
|
|
||||||
return new_string
|
|
||||||
|
|
||||||
def input_modifier(string):
|
def input_modifier(string):
|
||||||
"""
|
"""
|
||||||
@ -115,4 +110,4 @@ def ui():
|
|||||||
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
|
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
|
||||||
api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
|
api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
|
||||||
connect.click(check_valid_api, [], connection_status)
|
connect.click(check_valid_api, [], connection_status)
|
||||||
connect.click(refresh_voices, [], voice)
|
connect.click(refresh_voices, [], voice)
|
@ -1,15 +1,15 @@
|
|||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import modules.chat as chat
|
||||||
|
import modules.shared as shared
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import modules.chat as chat
|
|
||||||
import modules.shared as shared
|
|
||||||
|
|
||||||
torch._C._jit_set_profiling_mode(False)
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
|
||||||
# parameters which can be customized in settings.json of webui
|
# parameters which can be customized in settings.json of webui
|
||||||
@ -31,14 +31,9 @@ picture_response = False # specifies if the next model response should appear as
|
|||||||
pic_id = 0
|
pic_id = 0
|
||||||
|
|
||||||
def remove_surrounded_chars(string):
|
def remove_surrounded_chars(string):
|
||||||
new_string = ""
|
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||||
in_star = False
|
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||||
for char in string:
|
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||||
if char == '*':
|
|
||||||
in_star = not in_star
|
|
||||||
elif not in_star:
|
|
||||||
new_string += char
|
|
||||||
return new_string
|
|
||||||
|
|
||||||
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
|
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
|
||||||
def input_modifier(string):
|
def input_modifier(string):
|
||||||
@ -54,6 +49,8 @@ def input_modifier(string):
|
|||||||
mediums = ['image', 'pic', 'picture', 'photo']
|
mediums = ['image', 'pic', 'picture', 'photo']
|
||||||
subjects = ['yourself', 'own']
|
subjects = ['yourself', 'own']
|
||||||
lowstr = string.lower()
|
lowstr = string.lower()
|
||||||
|
|
||||||
|
# TODO: refactor out to separate handler and also replace detection with a regexp
|
||||||
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
|
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
|
||||||
picture_response = True
|
picture_response = True
|
||||||
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
||||||
@ -91,9 +88,8 @@ def get_SD_pictures(description):
|
|||||||
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
|
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
|
||||||
image.save(output_file.as_posix())
|
image.save(output_file.as_posix())
|
||||||
pic_id += 1
|
pic_id += 1
|
||||||
# lower the resolution of received images for the chat, otherwise the history size gets out of control quickly with all the base64 values
|
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||||
newsize = (300, 300)
|
image.thumbnail((300, 300))
|
||||||
image = image.resize(newsize, Image.LANCZOS)
|
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
image.save(buffered, format="JPEG")
|
image.save(buffered, format="JPEG")
|
||||||
buffered.seek(0)
|
buffered.seek(0)
|
||||||
@ -180,4 +176,4 @@ def ui():
|
|||||||
|
|
||||||
force_btn.click(force_pic)
|
force_btn.click(force_pic)
|
||||||
generate_now_btn.click(force_pic)
|
generate_now_btn.click(force_pic)
|
||||||
generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
@ -2,11 +2,11 @@ import base64
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
|
||||||
from transformers import BlipForConditionalGeneration, BlipProcessor
|
|
||||||
|
|
||||||
import modules.chat as chat
|
import modules.chat as chat
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import BlipForConditionalGeneration, BlipProcessor
|
||||||
|
|
||||||
# If 'state' is True, will hijack the next chat generation with
|
# If 'state' is True, will hijack the next chat generation with
|
||||||
# custom input text given by 'value' in the format [text, visible_text]
|
# custom input text given by 'value' in the format [text, visible_text]
|
||||||
@ -25,10 +25,12 @@ def caption_image(raw_image):
|
|||||||
|
|
||||||
def generate_chat_picture(picture, name1, name2):
|
def generate_chat_picture(picture, name1, name2):
|
||||||
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
|
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
|
||||||
|
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||||
|
picture.thumbnail((300, 300))
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
picture.save(buffer, format="JPEG")
|
picture.save(buffer, format="JPEG")
|
||||||
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||||
visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
|
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
|
||||||
return text, visible_text
|
return text, visible_text
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
|
||||||
|
|
||||||
import modules.chat as chat
|
import modules.chat as chat
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
import torch
|
||||||
|
|
||||||
torch._C._jit_set_profiling_mode(False)
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
|
||||||
@ -46,14 +46,9 @@ def load_model():
|
|||||||
model = load_model()
|
model = load_model()
|
||||||
|
|
||||||
def remove_surrounded_chars(string):
|
def remove_surrounded_chars(string):
|
||||||
new_string = ""
|
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||||
in_star = False
|
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||||
for char in string:
|
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||||
if char == '*':
|
|
||||||
in_star = not in_star
|
|
||||||
elif not in_star:
|
|
||||||
new_string += char
|
|
||||||
return new_string
|
|
||||||
|
|
||||||
def remove_tts_from_history(name1, name2):
|
def remove_tts_from_history(name1, name2):
|
||||||
for i, entry in enumerate(shared.history['internal']):
|
for i, entry in enumerate(shared.history['internal']):
|
||||||
@ -166,4 +161,4 @@ def ui():
|
|||||||
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
|
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
|
||||||
voice.change(lambda x: params.update({"speaker": x}), voice, None)
|
voice.change(lambda x: params.update({"speaker": x}), voice, None)
|
||||||
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
|
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
|
||||||
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
|
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
|
Loading…
Reference in New Issue
Block a user