tts : add guide tokens support (#11186)

* Added the ability to use guide tokens for OuteTTS, greatly improving TTS recitation accuracy over long input sequences.

* Applied linting suggestions, updated to the latest llama_vocab changes, added a safety check, and added a newline to the guide token start.
Author: LostRuins Concedo, 2025-01-18 18:20:57 +08:00 (committed by GitHub)
parent 44e18ef939
commit 6390a998bf
3 changed files with 53 additions and 1 deletion
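
With this change, guide tokens are opt-in. A typical invocation of the tts example might look like the following (the model file names and prompt are illustrative; -m, -mv and -p are the existing model, vocoder-model and prompt flags, while --tts-use-guide-tokens is the flag added here):

    llama-tts -m OuteTTS-0.2-500M-Q8_0.gguf -mv WavTokenizer-Large-75-F16.gguf \
        -p "some long passage to recite" --tts-use-guide-tokens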


@@ -2254,6 +2254,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.vocoder.model = value;
        }
    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
    add_opt(common_arg(
        {"--tts-use-guide-tokens"},
        "Use guide tokens to improve TTS word recall",
        [](common_params & params) {
            params.vocoder.use_guide_tokens = true;
        }
    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
    // model-specific
    add_opt(common_arg(


@@ -184,6 +184,8 @@ struct common_params_vocoder {
    std::string model     = ""; // model path // NOLINT
    std::string model_url = ""; // model url to download // NOLINT

    bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
};

struct common_params {


@@ -425,6 +425,33 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
    prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
}

static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
    const std::string & delimiter = "<|text_sep|>";

    std::vector<llama_token> result;

    size_t start = 0;
    size_t end   = str.find(delimiter);

    // the first token is always a newline, as it was not previously added
    result.push_back(common_tokenize(vocab, "\n", false, true)[0]);

    while (end != std::string::npos) {
        std::string current_word = str.substr(start, end - start);
        auto tmp = common_tokenize(vocab, current_word, false, true);
        result.push_back(tmp[0]);
        start = end + delimiter.length();
        end   = str.find(delimiter, start);
    }

    // add the last part
    std::string current_word = str.substr(start);
    auto tmp = common_tokenize(vocab, current_word, false, true);
    if (tmp.size() > 0) {
        result.push_back(tmp[0]);
    }

    return result;
}
int main(int argc, char ** argv) {
    common_params params;
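
For intuition, here is a minimal standalone sketch of the delimiter handling inside prepare_guide_tokens(), with the tokenizer stubbed out (split_words and the sample prompt are illustrative, not part of the patch); the real function also seeds the result with a newline token and keeps only the first token id of each fragment's tokenization:

    #include <cstdio>
    #include <string>
    #include <vector>

    // illustrative stand-in for the splitting done in prepare_guide_tokens():
    // the cleaned OuteTTS prompt separates words with "<|text_sep|>"
    static std::vector<std::string> split_words(const std::string & str) {
        const std::string delimiter = "<|text_sep|>";
        std::vector<std::string> result;
        size_t start = 0;
        size_t end   = str.find(delimiter);
        while (end != std::string::npos) {
            result.push_back(str.substr(start, end - start));
            start = end + delimiter.length();
            end   = str.find(delimiter, start);
        }
        result.push_back(str.substr(start)); // the part after the last separator
        return result;
    }

    int main() {
        // roughly what process_text() produces for "hello world again"
        const std::string prompt_clean = "hello<|text_sep|>world<|text_sep|>again";
        for (const auto & word : split_words(prompt_clean)) {
            printf("guide fragment: '%s'\n", word.c_str());
        }
        return 0;
    }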
@@ -494,6 +521,7 @@ int main(int argc, char ** argv) {
    const auto t_main_start = ggml_time_us();

    std::vector<llama_token> codes;
    std::vector<llama_token> guide_tokens;
    // process prompt and generate voice codes
    {
@@ -508,6 +536,9 @@
        // convert the input text into the necessary format expected by OuteTTS
        {
            std::string prompt_clean = process_text(params.prompt);

            if (params.vocoder.use_guide_tokens) {
                guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
            }

            LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
@@ -717,6 +748,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
    int n_past   = batch.n_tokens;
    int n_decode = 0;

    bool next_token_uses_guide_token = true;

    while (n_decode <= n_predict) {
        // prepare the next batch
        common_batch_clear(batch);
@@ -728,7 +761,17 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
                continue;
            }

            llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);

            // guide tokens help prevent hallucinations by forcing the TTS to use the correct word
            if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
                llama_token guide_token = guide_tokens[0];
                guide_tokens.erase(guide_tokens.begin());
                new_token_id = guide_token; // ensure the correct word fragment is used
            }

            // this is the token id that always precedes a new word
            next_token_uses_guide_token = (new_token_id == 198);

            common_sampler_accept(smpl[i], new_token_id, true);
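
To trace the override logic in isolation, here is a self-contained sketch with the sampler replaced by a fixed token sequence and the is_control/is_eog guards omitted for brevity (every token id except 198, which the patch treats as the newline preceding each new word, is made up):

    #include <cstdio>
    #include <deque>

    int main() {
        // guide list as built by prepare_guide_tokens(): newline first, then one token per word
        std::deque<int> guide_tokens = {198, 15339, 1917}; // "\n", "hello", "world" (ids invented except 198)
        const int sampled[] = {777, 888, 999, 198, 555};   // pretend sampler output
        bool next_token_uses_guide_token = true;

        for (int new_token_id : sampled) {
            // right after a newline, the sampled token is replaced by the next guide token
            if (!guide_tokens.empty() && next_token_uses_guide_token) {
                new_token_id = guide_tokens.front();
                guide_tokens.pop_front();
            }
            next_token_uses_guide_token = (new_token_id == 198);
            printf("accepted: %d\n", new_token_id);
        }
        return 0;
    }

This prints 198, 15339, 999, 198, 1917: each word's guide token lands immediately after a newline, which is why prepare_guide_tokens() seeds the list with a newline token before the first word.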