model in the TTS extensions clobbered global model

This commit is contained in:
Michael Sullivan 2023-07-21 02:33:11 -05:00
parent 87926d033d
commit 1c68c05b66
3 changed files with 15 additions and 13 deletions

View File

@ -1,4 +1,5 @@
import re
import time
from pathlib import Path
import elevenlabs
@ -93,7 +94,7 @@ def history_modifier(history):
return history
def output_modifier(string):
def output_modifier(string, state):
global params, wav_idx
if not params['activate']:
@ -108,7 +109,7 @@ def output_modifier(string):
if string == '':
string = 'empty reply, try regenerating'
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.mp3'.format(wav_idx))
output_file = Path(f'extensions/elevenlabs_tts/outputs/{state["character_menu"]}_{int(time.time())}.mp3'.format(wav_idx))
print(f'Outputting audio to {str(output_file)}')
try:
audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model=params['model'])
@ -160,7 +161,7 @@ def ui():
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
with gr.Row():
model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
tts_model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
with gr.Row():
convert = gr.Button('Permanently replace audios with the message texts')
@ -190,7 +191,7 @@ def ui():
activate.change(lambda x: params.update({'activate': x}), activate, None)
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
api_key.change(update_api_key, api_key, None)
model.change(lambda x: params.update({'model': x}), model, None)
tts_model.change(lambda x: params.update({'model': x}), tts_model, None)
# connect.click(check_valid_api, [], connection_status)
refresh.click(refresh_voices_dd, [], voice)
# Event functions to update the parameters in the backend

View File

@ -49,12 +49,12 @@ def load_model():
model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
if Path(model_path).is_file():
print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
tts_model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
else:
print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
model.to(params['device'])
return model
tts_model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
tts_model.to(params['device'])
return tts_model
def remove_tts_from_history(history):
@ -105,10 +105,10 @@ def history_modifier(history):
def output_modifier(string, state):
global model, current_params, streaming_state
global tts_model, current_params, streaming_state
for i in params:
if params[i] != current_params[i]:
model = load_model()
tts_model = load_model()
current_params = params.copy()
break
@ -124,7 +124,7 @@ def output_modifier(string, state):
output_file = Path(f'extensions/silero_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav')
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
tts_model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
autoplay = 'autoplay' if params['autoplay'] else ''
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
@ -136,8 +136,8 @@ def output_modifier(string, state):
def setup():
global model
model = load_model()
global tts_model
tts_model = load_model()
def ui():

View File

@ -13,6 +13,7 @@ is_seq2seq = False
model_name = "None"
lora_names = []
model_dirty_from_training = False
tts_model = None
# Chat variables
stop_everything = False