mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-23 18:09:18 +01:00
Add repetition penalty (#20)
* Adding repeat penalization * Update utils.h * Update utils.cpp * Numeric fix Should probably still scale by temp even if penalized * Update comments, more proper application I see that numbers can go negative so a fix from a referenced commit * Minor formatting --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
702fddf5c5
commit
129c7d1ea8
14
main.cpp
14
main.cpp
@ -792,7 +792,7 @@ int main(int argc, char ** argv) {
|
|||||||
printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
|
printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
printf("sampling parameters: temp = %f, top_k = %d, top_p = %f\n", params.temp, params.top_k, params.top_p);
|
printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
||||||
printf("\n\n");
|
printf("\n\n");
|
||||||
|
|
||||||
std::vector<gpt_vocab::id> embd;
|
std::vector<gpt_vocab::id> embd;
|
||||||
@ -801,6 +801,10 @@ int main(int argc, char ** argv) {
|
|||||||
size_t mem_per_token = 0;
|
size_t mem_per_token = 0;
|
||||||
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
||||||
|
|
||||||
|
int last_n_size = params.repeat_last_n;
|
||||||
|
std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
|
||||||
|
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||||
|
|
||||||
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
|
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
|
||||||
// predict
|
// predict
|
||||||
if (embd.size() > 0) {
|
if (embd.size() > 0) {
|
||||||
@ -821,6 +825,7 @@ int main(int argc, char ** argv) {
|
|||||||
// sample next token
|
// sample next token
|
||||||
const float top_p = params.top_p;
|
const float top_p = params.top_p;
|
||||||
const float temp = params.temp;
|
const float temp = params.temp;
|
||||||
|
const float repeat_penalty = params.repeat_penalty;
|
||||||
|
|
||||||
const int n_vocab = model.hparams.n_vocab;
|
const int n_vocab = model.hparams.n_vocab;
|
||||||
|
|
||||||
@ -829,7 +834,10 @@ int main(int argc, char ** argv) {
|
|||||||
{
|
{
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_p, temp, rng);
|
id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);
|
||||||
|
|
||||||
|
last_n_tokens.erase(last_n_tokens.begin());
|
||||||
|
last_n_tokens.push_back(id);
|
||||||
|
|
||||||
t_sample_us += ggml_time_us() - t_start_sample_us;
|
t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
}
|
}
|
||||||
@ -840,6 +848,8 @@ int main(int argc, char ** argv) {
|
|||||||
// if here, it means we are still processing the input prompt
|
// if here, it means we are still processing the input prompt
|
||||||
for (int k = i; k < embd_inp.size(); k++) {
|
for (int k = i; k < embd_inp.size(); k++) {
|
||||||
embd.push_back(embd_inp[k]);
|
embd.push_back(embd_inp[k]);
|
||||||
|
last_n_tokens.erase(last_n_tokens.begin());
|
||||||
|
last_n_tokens.push_back(embd_inp[k]);
|
||||||
if (embd.size() > params.n_batch) {
|
if (embd.size() > params.n_batch) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
19
utils.cpp
19
utils.cpp
@ -23,6 +23,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
params.top_p = std::stof(argv[++i]);
|
params.top_p = std::stof(argv[++i]);
|
||||||
} else if (arg == "--temp") {
|
} else if (arg == "--temp") {
|
||||||
params.temp = std::stof(argv[++i]);
|
params.temp = std::stof(argv[++i]);
|
||||||
|
} else if (arg == "--repeat_last_n") {
|
||||||
|
params.repeat_last_n = std::stoi(argv[++i]);
|
||||||
|
} else if (arg == "--repeat_penalty") {
|
||||||
|
params.repeat_penalty = std::stof(argv[++i]);
|
||||||
} else if (arg == "-b" || arg == "--batch_size") {
|
} else if (arg == "-b" || arg == "--batch_size") {
|
||||||
params.n_batch = std::stoi(argv[++i]);
|
params.n_batch = std::stoi(argv[++i]);
|
||||||
} else if (arg == "-m" || arg == "--model") {
|
} else if (arg == "-m" || arg == "--model") {
|
||||||
@ -52,6 +56,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
|
|||||||
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
|
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
|
||||||
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
|
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
|
||||||
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
|
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
|
||||||
|
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
|
||||||
|
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
|
||||||
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
|
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
|
||||||
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||||
@ -372,6 +378,8 @@ gpt_vocab::id gpt_sample_top_k_top_p(
|
|||||||
gpt_vocab::id llama_sample_top_p(
|
gpt_vocab::id llama_sample_top_p(
|
||||||
const gpt_vocab & vocab,
|
const gpt_vocab & vocab,
|
||||||
const float * logits,
|
const float * logits,
|
||||||
|
std::vector<gpt_vocab::id> & last_n_tokens,
|
||||||
|
double repeat_penalty,
|
||||||
double top_p,
|
double top_p,
|
||||||
double temp,
|
double temp,
|
||||||
std::mt19937 & rng) {
|
std::mt19937 & rng) {
|
||||||
@ -383,9 +391,20 @@ gpt_vocab::id llama_sample_top_p(
|
|||||||
{
|
{
|
||||||
const double scale = 1.0/temp;
|
const double scale = 1.0/temp;
|
||||||
for (int i = 0; i < n_logits; ++i) {
|
for (int i = 0; i < n_logits; ++i) {
|
||||||
|
// repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||||
|
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
|
||||||
|
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
|
||||||
|
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||||
|
if (logits[i] < 0.0) {
|
||||||
|
logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
|
||||||
|
} else {
|
||||||
|
logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
logits_id.push_back(std::make_pair(logits[i]*scale, i));
|
logits_id.push_back(std::make_pair(logits[i]*scale, i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::sort(
|
std::sort(
|
||||||
logits_id.begin(),
|
logits_id.begin(),
|
||||||
|
4
utils.h
4
utils.h
@ -16,11 +16,13 @@ struct gpt_params {
|
|||||||
int32_t seed = -1; // RNG seed
|
int32_t seed = -1; // RNG seed
|
||||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||||
int32_t n_predict = 128; // new tokens to predict
|
int32_t n_predict = 128; // new tokens to predict
|
||||||
|
int32_t repeat_last_n = 64; // last n tokens to penalize
|
||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
int32_t top_k = 40; // unused
|
int32_t top_k = 40; // unused
|
||||||
float top_p = 0.95f;
|
float top_p = 0.95f;
|
||||||
float temp = 0.80f;
|
float temp = 0.80f;
|
||||||
|
float repeat_penalty = 1.30f;
|
||||||
|
|
||||||
int32_t n_batch = 8; // batch size for prompt processing
|
int32_t n_batch = 8; // batch size for prompt processing
|
||||||
|
|
||||||
@ -89,6 +91,8 @@ gpt_vocab::id gpt_sample_top_k_top_p(
|
|||||||
gpt_vocab::id llama_sample_top_p(
|
gpt_vocab::id llama_sample_top_p(
|
||||||
const gpt_vocab & vocab,
|
const gpt_vocab & vocab,
|
||||||
const float * logits,
|
const float * logits,
|
||||||
|
std::vector<gpt_vocab::id> & last_n_tokens,
|
||||||
|
double repeat_penalty,
|
||||||
double top_p,
|
double top_p,
|
||||||
double temp,
|
double temp,
|
||||||
std::mt19937 & rng);
|
std::mt19937 & rng);
|
||||||
|
Loading…
Reference in New Issue
Block a user