diff --git a/common/arg.cpp b/common/arg.cpp
index 9069950eb..dede335fb 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2254,6 +2254,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.vocoder.model = value;
         }
     ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--tts-use-guide-tokens"},
+        "Use guide tokens to improve TTS word recall",
+        [](common_params & params) {
+            params.vocoder.use_guide_tokens = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
 
     // model-specific
     add_opt(common_arg(
diff --git a/common/common.h b/common/common.h
index 691141d6b..3bcc637cc 100644
--- a/common/common.h
+++ b/common/common.h
@@ -184,6 +184,8 @@ struct common_params_vocoder {
     std::string model     = ""; // model path                                    // NOLINT
     std::string model_url = ""; // model url to download                         // NOLINT
+
+    bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
 };
 
 struct common_params {
diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp
index 5a9161181..f78f76303 100644
--- a/examples/tts/tts.cpp
+++ b/examples/tts/tts.cpp
@@ -425,6 +425,33 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
     prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
 }
 
+static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
+    const std::string& delimiter = "<|text_sep|>";
+
+    std::vector<llama_token> result;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    //first token is always a newline, as it was not previously added
+    result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
+
+    while (end != std::string::npos) {
+        std::string current_word = str.substr(start, end - start);
+        auto tmp = common_tokenize(vocab, current_word, false, true);
+        result.push_back(tmp[0]);
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    // Add the last part
+    std::string current_word = str.substr(start);
+    auto tmp = common_tokenize(vocab, current_word, false, true);
+    if (tmp.size() > 0) {
+        result.push_back(tmp[0]);
+    }
+    return result;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
 
@@ -494,6 +521,7 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();
 
     std::vector<llama_token> codes;
+    std::vector<llama_token> guide_tokens;
 
     // process prompt and generate voice codes
     {
@@ -508,6 +536,9 @@ int main(int argc, char ** argv) {
         // convert the input text into the necessary format expected by OuteTTS
         {
             std::string prompt_clean = process_text(params.prompt);
+            if (params.vocoder.use_guide_tokens) {
+                guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
+            }
 
             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
 
@@ -717,6 +748,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
         int n_past   = batch.n_tokens;
         int n_decode = 0;
 
+        bool next_token_uses_guide_token = true;
+
         while (n_decode <= n_predict) {
             // prepare the next batch
             common_batch_clear(batch);
@@ -728,7 +761,17 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
                     continue;
                 }
 
-                const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+                llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+                //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+                if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
+                    llama_token guide_token = guide_tokens[0];
+                    guide_tokens.erase(guide_tokens.begin());
+                    new_token_id = guide_token; //ensure correct word fragment is used
+                }
+
+                //this is the token id that always precedes a new word
+                next_token_uses_guide_token = (new_token_id == 198);
 
                 common_sampler_accept(smpl[i], new_token_id, true);