From b5486956ff4c80c14d64d7366fa5784fe6f7f70a Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Sat, 18 Jan 2025 18:48:49 +0800
Subject: [PATCH] added rudimentary support for outetts v0.3 500m and 1b models

---
 examples/tts/tts.cpp | 47 +++++++++++++++++++++++++++++++----------------
 1 file changed, 31 insertions(+), 16 deletions(-)

diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp
index f78f76303..57b5ee33b 100644
--- a/examples/tts/tts.cpp
+++ b/examples/tts/tts.cpp
@@ -371,7 +371,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
 }
 
 // Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
-static std::string process_text(const std::string & text) {
+static std::string process_text(const std::string & text, bool is_version_0_3) {
 
     // For now I skipped text romanization as I am unsure how to handle
     // uroman and MeCab implementations in C++
@@ -401,7 +401,7 @@ static std::string process_text(const std::string & text) {
     if (c == ' ') {
         prompt_clean += "<|text_sep|>";
     */
-    processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
+    processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), is_version_0_3?"<|space|>":"<|text_sep|>");
 
     return processed_text;
 }
@@ -425,8 +425,7 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
     prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
 }
 
-static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
-    const std::string& delimiter = "<|text_sep|>";
+static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const std::string& delimiter) {
 
     std::vector<llama_token> result;
     size_t start = 0;
@@ -523,6 +522,11 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> codes;
     std::vector<llama_token> guide_tokens;
 
+    //determine OuteTTS version and vocab code offset. v0.2 does not have <|space|>, but v0.3 does
+    const bool is_version_0_3 = common_tokenize(vocab,"<|space|>",false,true).size()==1;
+    //determine the offset of the first audio code token
+    const int cts_offset = common_tokenize(vocab,"<|0|>",false,true)[0];
+
     // process prompt and generate voice codes
     {
         LOG_INF("%s: constructing prompt ..\n", __func__);
@@ -531,13 +535,17 @@
 
         prompt_init(prompt_inp, vocab);
 
-        prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
+        if (is_version_0_3) {
+            prompt_add(prompt_inp, vocab, "<|text_start|>the<|space|>overall<|space|>package<|space|>from<|space|>just<|space|>two<|space|>people<|space|>is<|space|>pretty<|space|>remarkable<|space|>sure<|space|>i<|space|>have<|space|>some<|space|>critiques<|space|>about<|space|>some<|space|>of<|space|>the<|space|>gameplay<|space|>aspects<|space|>but<|space|>its<|space|>still<|space|>really<|space|>enjoyable<|space|>and<|space|>it<|space|>looks<|space|>lovely<|space|>", false, true);
+        } else {
+            prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
+        }
 
         // convert the input text into the necessary format expected by OuteTTS
         {
-            std::string prompt_clean = process_text(params.prompt);
+            std::string prompt_clean = process_text(params.prompt, is_version_0_3);
 
             if (params.vocoder.use_guide_tokens) {
-                guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
+                guide_tokens = prepare_guide_tokens(vocab, prompt_clean, is_version_0_3?"<|space|>":"<|text_sep|>");
             }
 
             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
@@ -549,8 +557,8 @@ int main(int argc, char ** argv) {
 
         // disabled to save time on tokenizing each time
         // TODO: load voices from the json files
-#if 0
-        const std::string voice_data = R"(<|audio_start|>
+#if 1
+        std::string voice_data = R"(<|audio_start|>
 the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
 overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
 package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
@@ -582,12 +590,19 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
 looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
 lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
 
-        auto tmp = common_tokenize(vocab, voice_data, false, true);
-        printf("\n\n");
-        for (int i = 0; i < tmp.size(); ++i) {
-            printf("%d, ", tmp[i]);
+        if (is_version_0_3)
+        {
+            voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_start\|>)"), "");
+            voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
         }
-        printf("\n\n");
+
+        prompt_add(prompt_inp, vocab, voice_data, false, true);
+
+        // printf("\n\n");
+        // for (int i = 0; i < tmp.size(); ++i) {
+        //     printf("%d, ", tmp[i]);
+        // }
+        // printf("\n\n");
 #else
         prompt_add(prompt_inp, llama_tokens {
              151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
@@ -882,7 +897,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     }
 
     // remove all non-audio tokens (i.e. < 151672 || > 155772)
-    codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end());
+    codes.erase(std::remove_if(codes.begin(), codes.end(), [cts_offset](llama_token t) { return t < cts_offset || t > (cts_offset+4100); }), codes.end());
 
     {
         const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
@@ -891,7 +906,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     }
 
     for (auto & token : codes) {
-        token -= 151672;
+        token -= cts_offset;
     }
 
     const auto t_voc_start = ggml_time_us();
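
Reference note on the last two hunks: they replace the hard-coded OuteTTS v0.2 audio-token range (151672 to 155772) with a range derived from cts_offset, the token id that "<|0|>" tokenizes to. The following is a minimal standalone sketch of that filtering and rebasing step in plain STL; the cts_offset value and the sample token ids are illustrative assumptions, not values read from a real vocab.

// Standalone sketch of the audio-code filtering/remapping used in the patch.
// cts_offset and the sample ids below are assumed for illustration; in tts.cpp
// the offset is resolved by tokenizing "<|0|>" against the loaded model vocab.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t;

int main() {
    const llama_token cts_offset = 151672;  // hypothetical id of <|0|> (matches the old v0.2 constant)
    std::vector<llama_token> codes = {151667, 151672, 153000, 155772, 155773, 198};

    // drop everything outside the audio-code range [cts_offset, cts_offset + 4100]
    codes.erase(std::remove_if(codes.begin(), codes.end(), [cts_offset](llama_token t) {
        return t < cts_offset || t > (cts_offset + 4100);
    }), codes.end());

    // rebase the surviving ids so <|0|> maps to 0 before they are handed to the vocoder
    for (auto & token : codes) {
        token -= cts_offset;
    }

    for (auto t : codes) {
        printf("%d ", t);  // prints: 0 1328 4100
    }
    printf("\n");
    return 0;
}

With cts_offset equal to 151672 this reproduces the old hard-coded behaviour exactly; with a v0.3 vocab the same code follows wherever "<|0|>" lands.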