mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 09:19:23 +01:00
Global model aded clobal tts_model to fix.
This commit is contained in:
parent
87fe313d6c
commit
6d87415d61
@ -1,5 +1,4 @@
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import elevenlabs
|
||||
@ -94,7 +93,7 @@ def history_modifier(history):
|
||||
return history
|
||||
|
||||
|
||||
def output_modifier(string, state):
|
||||
def output_modifier(string):
|
||||
global params, wav_idx
|
||||
|
||||
if not params['activate']:
|
||||
@ -109,7 +108,7 @@ def output_modifier(string, state):
|
||||
if string == '':
|
||||
string = 'empty reply, try regenerating'
|
||||
|
||||
output_file = Path(f'extensions/elevenlabs_tts/outputs/{state["character_menu"]}_{int(time.time())}.mp3'.format(wav_idx))
|
||||
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.mp3'.format(wav_idx))
|
||||
print(f'Outputting audio to {str(output_file)}')
|
||||
try:
|
||||
audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model=params['model'])
|
||||
@ -161,7 +160,7 @@ def ui():
|
||||
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
|
||||
|
||||
with gr.Row():
|
||||
tts_model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
|
||||
model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
|
||||
|
||||
with gr.Row():
|
||||
convert = gr.Button('Permanently replace audios with the message texts')
|
||||
@ -191,7 +190,7 @@ def ui():
|
||||
activate.change(lambda x: params.update({'activate': x}), activate, None)
|
||||
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
|
||||
api_key.change(update_api_key, api_key, None)
|
||||
tts_model.change(lambda x: params.update({'model': x}), tts_model, None)
|
||||
model.change(lambda x: params.update({'model': x}), model, None)
|
||||
# connect.click(check_valid_api, [], connection_status)
|
||||
refresh.click(refresh_voices_dd, [], voice)
|
||||
# Event functions to update the parameters in the backend
|
||||
|
@ -49,12 +49,12 @@ def load_model():
|
||||
model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
|
||||
if Path(model_path).is_file():
|
||||
print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
|
||||
tts_model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
|
||||
model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
|
||||
else:
|
||||
print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
|
||||
tts_model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||
tts_model.to(params['device'])
|
||||
return tts_model
|
||||
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||
model.to(params['device'])
|
||||
return model
|
||||
|
||||
|
||||
def remove_tts_from_history(history):
|
||||
@ -105,10 +105,10 @@ def history_modifier(history):
|
||||
|
||||
|
||||
def output_modifier(string, state):
|
||||
global tts_model, current_params, streaming_state
|
||||
global model, current_params, streaming_state
|
||||
for i in params:
|
||||
if params[i] != current_params[i]:
|
||||
tts_model = load_model()
|
||||
model = load_model()
|
||||
current_params = params.copy()
|
||||
break
|
||||
|
||||
@ -124,7 +124,7 @@ def output_modifier(string, state):
|
||||
output_file = Path(f'extensions/silero_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav')
|
||||
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
|
||||
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
|
||||
tts_model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
|
||||
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
|
||||
|
||||
autoplay = 'autoplay' if params['autoplay'] else ''
|
||||
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
|
||||
@ -136,8 +136,8 @@ def output_modifier(string, state):
|
||||
|
||||
|
||||
def setup():
|
||||
global tts_model
|
||||
tts_model = load_model()
|
||||
global model
|
||||
model = load_model()
|
||||
|
||||
|
||||
def ui():
|
||||
|
@ -13,7 +13,6 @@ is_seq2seq = False
|
||||
model_name = "None"
|
||||
lora_names = []
|
||||
model_dirty_from_training = False
|
||||
tts_model = None
|
||||
|
||||
# Chat variables
|
||||
stop_everything = False
|
||||
|
Loading…
Reference in New Issue
Block a user