mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-05 08:00:42 +01:00
added rudimentary support for outetts v0.3 500m and 1b models
This commit is contained in:
parent
6390a998bf
commit
b5486956ff
@ -371,7 +371,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
|
||||
}
|
||||
|
||||
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
|
||||
static std::string process_text(const std::string & text) {
|
||||
static std::string process_text(const std::string & text, bool is_version_0_3) {
|
||||
|
||||
// For now I skipped text romanization as I am unsure how to handle
|
||||
// uroman and MeCab implementations in C++
|
||||
@ -401,7 +401,7 @@ static std::string process_text(const std::string & text) {
|
||||
if (c == ' ') {
|
||||
prompt_clean += "<|text_sep|>";
|
||||
*/
|
||||
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
|
||||
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), is_version_0_3?"<|space|>":"<|text_sep|>");
|
||||
|
||||
return processed_text;
|
||||
}
|
||||
@ -425,8 +425,7 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
|
||||
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
|
||||
}
|
||||
|
||||
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
|
||||
const std::string& delimiter = "<|text_sep|>";
|
||||
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const std::string& delimiter) {
|
||||
|
||||
std::vector<llama_token> result;
|
||||
size_t start = 0;
|
||||
@ -523,6 +522,11 @@ int main(int argc, char ** argv) {
|
||||
std::vector<llama_token> codes;
|
||||
std::vector<llama_token> guide_tokens;
|
||||
|
||||
//determine OuteTTS version and vocab code offset. v0.2 does not have <|space|>, but v0.3 does
|
||||
const bool is_version_0_3 = common_tokenize(vocab,"<|space|>",false,true).size()==1;
|
||||
//determine the offset of the first audio code token
|
||||
const int cts_offset = common_tokenize(vocab,"<|0|>",false,true)[0];
|
||||
|
||||
// process prompt and generate voice codes
|
||||
{
|
||||
LOG_INF("%s: constructing prompt ..\n", __func__);
|
||||
@ -531,13 +535,17 @@ int main(int argc, char ** argv) {
|
||||
|
||||
prompt_init(prompt_inp, vocab);
|
||||
|
||||
prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
|
||||
if (is_version_0_3) {
|
||||
prompt_add(prompt_inp, vocab, "<|text_start|>the<|space|>overall<|space|>package<|space|>from<|space|>just<|space|>two<|space|>people<|space|>is<|space|>pretty<|space|>remarkable<|space|>sure<|space|>i<|space|>have<|space|>some<|space|>critiques<|space|>about<|space|>some<|space|>of<|space|>the<|space|>gameplay<|space|>aspects<|space|>but<|space|>its<|space|>still<|space|>really<|space|>enjoyable<|space|>and<|space|>it<|space|>looks<|space|>lovely<|space|>", false, true);
|
||||
} else {
|
||||
prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
|
||||
}
|
||||
|
||||
// convert the input text into the necessary format expected by OuteTTS
|
||||
{
|
||||
std::string prompt_clean = process_text(params.prompt);
|
||||
std::string prompt_clean = process_text(params.prompt, is_version_0_3);
|
||||
if (params.vocoder.use_guide_tokens) {
|
||||
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
|
||||
guide_tokens = prepare_guide_tokens(vocab, prompt_clean, is_version_0_3?"<|space|>":"<|text_sep|>");
|
||||
}
|
||||
|
||||
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
|
||||
@ -549,8 +557,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// disabled to save time on tokenizing each time
|
||||
// TODO: load voices from the json files
|
||||
#if 0
|
||||
const std::string voice_data = R"(<|audio_start|>
|
||||
#if 1
|
||||
std::string voice_data = R"(<|audio_start|>
|
||||
the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
|
||||
overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
|
||||
package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
|
||||
@ -582,12 +590,19 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
|
||||
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
|
||||
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
|
||||
|
||||
auto tmp = common_tokenize(vocab, voice_data, false, true);
|
||||
printf("\n\n");
|
||||
for (int i = 0; i < tmp.size(); ++i) {
|
||||
printf("%d, ", tmp[i]);
|
||||
if (is_version_0_3)
|
||||
{
|
||||
voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_start\|>)"), "");
|
||||
voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
|
||||
}
|
||||
printf("\n\n");
|
||||
|
||||
prompt_add(prompt_inp, vocab, voice_data, false, true);
|
||||
|
||||
// printf("\n\n");
|
||||
// for (int i = 0; i < tmp.size(); ++i) {
|
||||
// printf("%d, ", tmp[i]);
|
||||
// }
|
||||
// printf("\n\n");
|
||||
#else
|
||||
prompt_add(prompt_inp, llama_tokens {
|
||||
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
|
||||
@ -882,7 +897,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
||||
}
|
||||
|
||||
// remove all non-audio tokens (i.e. < 151672 || > 155772)
|
||||
codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end());
|
||||
codes.erase(std::remove_if(codes.begin(), codes.end(), [cts_offset](llama_token t) { return t < cts_offset || t > (cts_offset+4100); }), codes.end());
|
||||
|
||||
{
|
||||
const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
|
||||
@ -891,7 +906,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
||||
}
|
||||
|
||||
for (auto & token : codes) {
|
||||
token -= 151672;
|
||||
token -= cts_offset;
|
||||
}
|
||||
|
||||
const auto t_voc_start = ggml_time_us();
|
||||
|
Loading…
Reference in New Issue
Block a user