mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Global model added global tts_model to fix.
This commit is contained in:
parent
87fe313d6c
commit
6d87415d61
@ -1,5 +1,4 @@
|
|||||||
import re
|
import re
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import elevenlabs
|
import elevenlabs
|
||||||
@ -94,7 +93,7 @@ def history_modifier(history):
|
|||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
def output_modifier(string, state):
|
def output_modifier(string):
|
||||||
global params, wav_idx
|
global params, wav_idx
|
||||||
|
|
||||||
if not params['activate']:
|
if not params['activate']:
|
||||||
@ -109,7 +108,7 @@ def output_modifier(string, state):
|
|||||||
if string == '':
|
if string == '':
|
||||||
string = 'empty reply, try regenerating'
|
string = 'empty reply, try regenerating'
|
||||||
|
|
||||||
output_file = Path(f'extensions/elevenlabs_tts/outputs/{state["character_menu"]}_{int(time.time())}.mp3'.format(wav_idx))
|
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.mp3'.format(wav_idx))
|
||||||
print(f'Outputting audio to {str(output_file)}')
|
print(f'Outputting audio to {str(output_file)}')
|
||||||
try:
|
try:
|
||||||
audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model=params['model'])
|
audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model=params['model'])
|
||||||
@ -161,7 +160,7 @@ def ui():
|
|||||||
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
|
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
tts_model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
|
model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
convert = gr.Button('Permanently replace audios with the message texts')
|
convert = gr.Button('Permanently replace audios with the message texts')
|
||||||
@ -191,7 +190,7 @@ def ui():
|
|||||||
activate.change(lambda x: params.update({'activate': x}), activate, None)
|
activate.change(lambda x: params.update({'activate': x}), activate, None)
|
||||||
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
|
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
|
||||||
api_key.change(update_api_key, api_key, None)
|
api_key.change(update_api_key, api_key, None)
|
||||||
tts_model.change(lambda x: params.update({'model': x}), tts_model, None)
|
model.change(lambda x: params.update({'model': x}), model, None)
|
||||||
# connect.click(check_valid_api, [], connection_status)
|
# connect.click(check_valid_api, [], connection_status)
|
||||||
refresh.click(refresh_voices_dd, [], voice)
|
refresh.click(refresh_voices_dd, [], voice)
|
||||||
# Event functions to update the parameters in the backend
|
# Event functions to update the parameters in the backend
|
||||||
|
@ -49,12 +49,12 @@ def load_model():
|
|||||||
model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
|
model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
|
||||||
if Path(model_path).is_file():
|
if Path(model_path).is_file():
|
||||||
print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
|
print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
|
||||||
tts_model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
|
model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
|
||||||
else:
|
else:
|
||||||
print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
|
print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
|
||||||
tts_model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||||
tts_model.to(params['device'])
|
model.to(params['device'])
|
||||||
return tts_model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def remove_tts_from_history(history):
|
def remove_tts_from_history(history):
|
||||||
@ -105,10 +105,10 @@ def history_modifier(history):
|
|||||||
|
|
||||||
|
|
||||||
def output_modifier(string, state):
|
def output_modifier(string, state):
|
||||||
global tts_model, current_params, streaming_state
|
global model, current_params, streaming_state
|
||||||
for i in params:
|
for i in params:
|
||||||
if params[i] != current_params[i]:
|
if params[i] != current_params[i]:
|
||||||
tts_model = load_model()
|
model = load_model()
|
||||||
current_params = params.copy()
|
current_params = params.copy()
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -124,7 +124,7 @@ def output_modifier(string, state):
|
|||||||
output_file = Path(f'extensions/silero_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav')
|
output_file = Path(f'extensions/silero_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav')
|
||||||
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
|
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
|
||||||
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
|
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
|
||||||
tts_model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
|
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
|
||||||
|
|
||||||
autoplay = 'autoplay' if params['autoplay'] else ''
|
autoplay = 'autoplay' if params['autoplay'] else ''
|
||||||
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
|
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
|
||||||
@ -136,8 +136,8 @@ def output_modifier(string, state):
|
|||||||
|
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
global tts_model
|
global model
|
||||||
tts_model = load_model()
|
model = load_model()
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
|
@ -13,7 +13,6 @@ is_seq2seq = False
|
|||||||
model_name = "None"
|
model_name = "None"
|
||||||
lora_names = []
|
lora_names = []
|
||||||
model_dirty_from_training = False
|
model_dirty_from_training = False
|
||||||
tts_model = None
|
|
||||||
|
|
||||||
# Chat variables
|
# Chat variables
|
||||||
stop_everything = False
|
stop_everything = False
|
||||||
|
Loading…
Reference in New Issue
Block a user