mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-31 06:03:11 +01:00
initial sampling changes:
This commit is contained in:
parent
f7cd13301c
commit
ddc3c2208a
@ -95,6 +95,7 @@ enum common_sampler_type {
|
|||||||
COMMON_SAMPLER_TYPE_XTC = 8,
|
COMMON_SAMPLER_TYPE_XTC = 8,
|
||||||
COMMON_SAMPLER_TYPE_INFILL = 9,
|
COMMON_SAMPLER_TYPE_INFILL = 9,
|
||||||
COMMON_SAMPLER_TYPE_PENALTIES = 10,
|
COMMON_SAMPLER_TYPE_PENALTIES = 10,
|
||||||
|
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11
|
||||||
};
|
};
|
||||||
|
|
||||||
// dimensionality reduction methods, used by cvector-generator
|
// dimensionality reduction methods, used by cvector-generator
|
||||||
@ -128,6 +129,7 @@ struct common_params_sampling {
|
|||||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||||
|
int32_t top_n_sigma = 2;
|
||||||
float mirostat_tau = 5.00f; // target entropy
|
float mirostat_tau = 5.00f; // target entropy
|
||||||
float mirostat_eta = 0.10f; // learning rate
|
float mirostat_eta = 0.10f; // learning rate
|
||||||
bool ignore_eos = false;
|
bool ignore_eos = false;
|
||||||
@ -146,6 +148,7 @@ struct common_params_sampling {
|
|||||||
COMMON_SAMPLER_TYPE_MIN_P,
|
COMMON_SAMPLER_TYPE_MIN_P,
|
||||||
COMMON_SAMPLER_TYPE_XTC,
|
COMMON_SAMPLER_TYPE_XTC,
|
||||||
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
||||||
|
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||||
|
@ -199,6 +199,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||||
break;
|
break;
|
||||||
|
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
||||||
|
// llama_sampler_chain_add(result->chain, )
|
||||||
|
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma))
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false && "unknown sampler type");
|
GGML_ASSERT(false && "unknown sampler type");
|
||||||
}
|
}
|
||||||
@ -407,6 +411,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
|||||||
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
||||||
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
||||||
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
||||||
|
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
|
||||||
default : return '?';
|
default : return '?';
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -422,6 +427,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
|||||||
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
||||||
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
||||||
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
||||||
|
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
|
||||||
default : return "";
|
default : return "";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -437,6 +443,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|||||||
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
||||||
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
||||||
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
||||||
|
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
||||||
};
|
};
|
||||||
|
|
||||||
// since samplers names are written multiple ways
|
// since samplers names are written multiple ways
|
||||||
@ -451,6 +458,9 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|||||||
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||||
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
||||||
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||||
|
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
||||||
|
{ "top-nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
||||||
|
{ "top_nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<common_sampler_type> samplers;
|
std::vector<common_sampler_type> samplers;
|
||||||
@ -484,6 +494,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
|||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
||||||
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<common_sampler_type> samplers;
|
std::vector<common_sampler_type> samplers;
|
||||||
|
@ -1645,6 +1645,48 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// top-n-sigma
|
||||||
|
|
||||||
|
// Context object for the top-n-sigma sampler.
// The sampler's apply callback reads `n` and forwards it to llama_sampler_top_n_sigma_impl().
struct llama_sampler_top_n_sigma {
    // NOTE(review): presumably the number of standard deviations used as the cutoff - confirm against the impl.
    const int32_t n;
};
|
||||||
|
|
||||||
|
// Returns the fixed display name of the top-n-sigma sampler.
// The sampler argument is unused; the name does not depend on sampler state.
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
    return "top-n-sigma";
}
|
||||||
|
|
||||||
|
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||||
|
llama_sampler_top_n_sigma_impl(cur_p, ctx->n);
|
||||||
|
}
|
||||||
|
|
||||||
|
// static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
|
||||||
|
// const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
|
||||||
|
// return llama_sampler_init_top_k(ctx->k);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
||||||
|
// delete (llama_sampler_top_k *) smpl->ctx;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// static struct llama_sampler_i llama_sampler_top_k_i = {
|
||||||
|
// /* .name = */ llama_sampler_top_k_name,
|
||||||
|
// /* .accept = */ nullptr,
|
||||||
|
// /* .apply = */ llama_sampler_top_k_apply,
|
||||||
|
// /* .reset = */ nullptr,
|
||||||
|
// /* .clone = */ llama_sampler_top_k_clone,
|
||||||
|
// /* .free = */ llama_sampler_top_k_free,
|
||||||
|
// };
|
||||||
|
|
||||||
|
// struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
||||||
|
// return new llama_sampler {
|
||||||
|
// /* .iface = */ &llama_sampler_top_k_i,
|
||||||
|
// /* .ctx = */ new llama_sampler_top_k {
|
||||||
|
// /* .k = */ k,
|
||||||
|
// },
|
||||||
|
// };
|
||||||
|
// }
|
||||||
|
|
||||||
// DRY
|
// DRY
|
||||||
|
|
||||||
struct llama_sampler_dry {
|
struct llama_sampler_dry {
|
||||||
|
Loading…
Reference in New Issue
Block a user