Add "remove_trailing_dots" option to XTTSv2

This commit is contained in:
oobabooga 2023-11-20 18:33:29 -08:00
parent 8dc9ec3491
commit 829c6d4f78

View File

@ -13,6 +13,7 @@ from modules.utils import gradio
try: try:
from TTS.api import TTS from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer
except ModuleNotFoundError: except ModuleNotFoundError:
logger.error( logger.error(
"Could not find the TTS module. Make sure to install the requirements for the coqui_tts extension." "Could not find the TTS module. Make sure to install the requirements for the coqui_tts extension."
@ -30,6 +31,7 @@ params = {
"activate": True, "activate": True,
"autoplay": True, "autoplay": True,
"show_text": False, "show_text": False,
"remove_trailing_dots": False,
"voice": "female_01.wav", "voice": "female_01.wav",
"language": "English", "language": "English",
"model_name": "tts_models/multilingual/multi-dataset/xtts_v2", "model_name": "tts_models/multilingual/multi-dataset/xtts_v2",
@ -52,6 +54,24 @@ def preprocess(raw_input):
return raw_input return raw_input
def new_split_into_sentences(self, text):
sentences = self.seg.segment(text)
if params['remove_trailing_dots']:
sentences_without_dots = []
for sentence in sentences:
if sentence.endswith('.') and not sentence.endswith('...'):
sentence = sentence[:-1]
sentences_without_dots.append(sentence)
return sentences_without_dots
else:
return sentences
Synthesizer.split_into_sentences = new_split_into_sentences
def load_model(): def load_model():
model = TTS(params["model_name"]).to(params["device"]) model = TTS(params["model_name"]).to(params["device"])
return model return model
@ -170,6 +190,7 @@ def ui():
with gr.Row(): with gr.Row():
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player') show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
remove_trailing_dots = gr.Checkbox(value=params['remove_trailing_dots'], label='Remove trailing "." from text segments before converting to audio')
with gr.Row(): with gr.Row():
with gr.Row(): with gr.Row():
@ -209,6 +230,7 @@ def ui():
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None) activate.change(lambda x: params.update({"activate": x}), activate, None)
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None) autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
remove_trailing_dots.change(lambda x: params.update({"remove_trailing_dots": x}), remove_trailing_dots, None)
voice.change(lambda x: params.update({"voice": x}), voice, None) voice.change(lambda x: params.update({"voice": x}), voice, None)
language.change(lambda x: params.update({"language": x}), language, None) language.change(lambda x: params.update({"language": x}), language, None)