From f6bf74dcd593308f5104d5c100b4901c5b0fb365 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 14 Feb 2023 15:06:06 -0300 Subject: [PATCH] Add Silero TTS extension --- extensions/silero_tts/requirements.txt | 6 ++ extensions/silero_tts/script.py | 79 ++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 extensions/silero_tts/requirements.txt create mode 100644 extensions/silero_tts/script.py diff --git a/extensions/silero_tts/requirements.txt b/extensions/silero_tts/requirements.txt new file mode 100644 index 00000000..f2f0bff5 --- /dev/null +++ b/extensions/silero_tts/requirements.txt @@ -0,0 +1,6 @@ +ipython +omegaconf +pydub +PyYAML +torch +torchaudio diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py new file mode 100644 index 00000000..8e86ff17 --- /dev/null +++ b/extensions/silero_tts/script.py @@ -0,0 +1,79 @@ +import asyncio +from pathlib import Path + +import torch + +torch._C._jit_set_profiling_mode(False) + +params = { + 'speaker': 'en_21', + 'language': 'en', + 'model_id': 'v3_en', + 'sample_rate': 48000, + 'device': 'cpu', +} +current_params = params.copy() +wav_idx = 0 + +def load_model(): + model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) + model.to(params['device']) + return model +model = load_model() + +def remove_surrounded_chars(string): + new_string = "" + in_star = False + for char in string: + if char == '*': + in_star = not in_star + elif not in_star: + new_string += char + return new_string + +def input_modifier(string): + """ + This function is applied to your text inputs before + they are fed into the model. + """ + + return string + +def output_modifier(string): + """ + This function is applied to the model outputs. + """ + + global wav_idx, model, current_params + + for i in params: + if params[i] != current_params[i]: + model = load_model() + current_params = params.copy() + break + + string = remove_surrounded_chars(string) + string = string.replace('"', '') + string = string.replace('“', '') + string = string.replace('\n', ' ') + string = string.strip() + + if string == '': + string = 'empty reply, try regenerating' + + output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav') + audio = model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) + + string = f'' + wav_idx += 1 + + return string + +def bot_prefix_modifier(string): + """ + This function is only applied in chat mode. It modifies + the prefix text for the Bot and can be used to bias its + behavior. + """ + + return string