Global model aded clobal tts_model to fix.

This commit is contained in:
Michael Sullivan 2023-07-21 02:57:44 -05:00
parent 87fe313d6c
commit 6d87415d61
3 changed files with 13 additions and 15 deletions

View File

@ -1,5 +1,4 @@
import re import re
import time
from pathlib import Path from pathlib import Path
import elevenlabs import elevenlabs
@ -94,7 +93,7 @@ def history_modifier(history):
return history return history
def output_modifier(string, state): def output_modifier(string):
global params, wav_idx global params, wav_idx
if not params['activate']: if not params['activate']:
@ -109,7 +108,7 @@ def output_modifier(string, state):
if string == '': if string == '':
string = 'empty reply, try regenerating' string = 'empty reply, try regenerating'
output_file = Path(f'extensions/elevenlabs_tts/outputs/{state["character_menu"]}_{int(time.time())}.mp3'.format(wav_idx)) output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.mp3'.format(wav_idx))
print(f'Outputting audio to {str(output_file)}') print(f'Outputting audio to {str(output_file)}')
try: try:
audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model=params['model']) audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model=params['model'])
@ -161,7 +160,7 @@ def ui():
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key') api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
with gr.Row(): with gr.Row():
tts_model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model') model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
with gr.Row(): with gr.Row():
convert = gr.Button('Permanently replace audios with the message texts') convert = gr.Button('Permanently replace audios with the message texts')
@ -191,7 +190,7 @@ def ui():
activate.change(lambda x: params.update({'activate': x}), activate, None) activate.change(lambda x: params.update({'activate': x}), activate, None)
voice.change(lambda x: params.update({'selected_voice': x}), voice, None) voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
api_key.change(update_api_key, api_key, None) api_key.change(update_api_key, api_key, None)
tts_model.change(lambda x: params.update({'model': x}), tts_model, None) model.change(lambda x: params.update({'model': x}), model, None)
# connect.click(check_valid_api, [], connection_status) # connect.click(check_valid_api, [], connection_status)
refresh.click(refresh_voices_dd, [], voice) refresh.click(refresh_voices_dd, [], voice)
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend

View File

@ -49,12 +49,12 @@ def load_model():
model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt" model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
if Path(model_path).is_file(): if Path(model_path).is_file():
print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}') print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
tts_model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True) model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
else: else:
print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...') print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
tts_model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
tts_model.to(params['device']) model.to(params['device'])
return tts_model return model
def remove_tts_from_history(history): def remove_tts_from_history(history):
@ -105,10 +105,10 @@ def history_modifier(history):
def output_modifier(string, state): def output_modifier(string, state):
global tts_model, current_params, streaming_state global model, current_params, streaming_state
for i in params: for i in params:
if params[i] != current_params[i]: if params[i] != current_params[i]:
tts_model = load_model() model = load_model()
current_params = params.copy() current_params = params.copy()
break break
@ -124,7 +124,7 @@ def output_modifier(string, state):
output_file = Path(f'extensions/silero_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav') output_file = Path(f'extensions/silero_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav')
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch']) prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>' silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
tts_model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
autoplay = 'autoplay' if params['autoplay'] else '' autoplay = 'autoplay' if params['autoplay'] else ''
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>' string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
@ -136,8 +136,8 @@ def output_modifier(string, state):
def setup(): def setup():
global tts_model global model
tts_model = load_model() model = load_model()
def ui(): def ui():

View File

@ -13,7 +13,6 @@ is_seq2seq = False
model_name = "None" model_name = "None"
lora_names = [] lora_names = []
model_dirty_from_training = False model_dirty_from_training = False
tts_model = None
# Chat variables # Chat variables
stop_everything = False stop_everything = False