llama : expose model's rope_freq_scale in the API (#3418)

so it can be scaled further before creating a context.
This commit is contained in:
Alex Klinkhamer 2023-10-03 10:09:28 -07:00 committed by GitHub
parent f56e1baec3
commit 48be797ffb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 0 deletions

View File

@ -7038,6 +7038,10 @@ int llama_n_embd(const struct llama_model * model) {
return model->hparams.n_embd; return model->hparams.n_embd;
} }
float llama_rope_freq_scale_train(const struct llama_model * model) {
return model->hparams.rope_freq_scale_train;
}
int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
return snprintf(buf, buf_size, "%s %s %s", return snprintf(buf, buf_size, "%s %s %s",
llama_model_arch_name(model->arch).c_str(), llama_model_arch_name(model->arch).c_str(),

View File

@ -282,6 +282,9 @@ extern "C" {
LLAMA_API int llama_n_ctx_train(const struct llama_model * model); LLAMA_API int llama_n_ctx_train(const struct llama_model * model);
LLAMA_API int llama_n_embd (const struct llama_model * model); LLAMA_API int llama_n_embd (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
// Get a string describing the model type // Get a string describing the model type
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);