From e563b015d895439362f452559811568bedb6b054 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A6=CF=86?= <42910943+Brawlence@users.noreply.github.com> Date: Fri, 7 Apr 2023 18:15:57 +0300 Subject: [PATCH] Silero TTS offline cache (#628) --- extensions/silero_tts/script.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index b001ae5a..d601ff53 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -21,6 +21,7 @@ params = { 'autoplay': True, 'voice_pitch': 'medium', 'voice_speed': 'medium', + 'local_cache_path': '' # User can override the default cache path to something other via settings.json } current_params = params.copy() @@ -44,14 +45,18 @@ def xmlesc(txt): def load_model(): - model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) + torch_cache_path = torch.hub.get_dir() if params['local_cache_path'] == '' else params['local_cache_path'] + model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt" + if Path(model_path).is_file(): + print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}') + model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True) + else: + print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...') + model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) model.to(params['device']) return model -model = load_model() - - def remove_tts_from_history(name1, name2, mode): for i, entry in enumerate(shared.history['internal']): shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]] @@ -132,6 +137,11 @@ def bot_prefix_modifier(string): return string +def setup(): + global model + model = load_model() + + def ui(): # Gradio elements with gr.Accordion("Silero TTS"):