Silero TTS offline cache (#628)

This commit is contained in:
Φφ 2023-04-07 18:15:57 +03:00 committed by GitHub
parent 1c413ed593
commit e563b015d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -21,6 +21,7 @@ params = {
'autoplay': True, 'autoplay': True,
'voice_pitch': 'medium', 'voice_pitch': 'medium',
'voice_speed': 'medium', 'voice_speed': 'medium',
'local_cache_path': '' # User can override the default cache path to something other via settings.json
} }
current_params = params.copy() current_params = params.copy()
@ -44,14 +45,18 @@ def xmlesc(txt):
def load_model(): def load_model():
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) torch_cache_path = torch.hub.get_dir() if params['local_cache_path'] == '' else params['local_cache_path']
model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
if Path(model_path).is_file():
print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
else:
print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
model.to(params['device']) model.to(params['device'])
return model return model
model = load_model()
def remove_tts_from_history(name1, name2, mode): def remove_tts_from_history(name1, name2, mode):
for i, entry in enumerate(shared.history['internal']): for i, entry in enumerate(shared.history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]] shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
@ -132,6 +137,11 @@ def bot_prefix_modifier(string):
return string return string
def setup():
global model
model = load_model()
def ui(): def ui():
# Gradio elements # Gradio elements
with gr.Accordion("Silero TTS"): with gr.Accordion("Silero TTS"):