diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 59ed15504..01d5ca57f 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -304,3 +304,65 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         batch.logits = logits.data();
     }
 }
+
+//
+// interface implementation
+//
+
+struct llama_batch llama_batch_get_one(
+             llama_token * tokens,
+                 int32_t   n_tokens) {
+    return {
+        /*n_tokens =*/ n_tokens,
+        /*tokens   =*/ tokens,
+        /*embd     =*/ nullptr,
+        /*pos      =*/ nullptr,
+        /*n_seq_id =*/ nullptr,
+        /*seq_id   =*/ nullptr,
+        /*logits   =*/ nullptr,
+    };
+}
+
+struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
+    llama_batch batch = {
+        /*n_tokens =*/ 0,
+        /*tokens   =*/ nullptr,
+        /*embd     =*/ nullptr,
+        /*pos      =*/ nullptr,
+        /*n_seq_id =*/ nullptr,
+        /*seq_id   =*/ nullptr,
+        /*logits   =*/ nullptr,
+    };
+
+    if (embd) {
+        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
+    } else {
+        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
+    }
+
+    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
+    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
+    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
+    for (int i = 0; i < n_tokens_alloc; ++i) {
+        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+    }
+    batch.seq_id[n_tokens_alloc] = nullptr;
+
+    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);
+
+    return batch;
+}
+
+void llama_batch_free(struct llama_batch batch) {
+    if (batch.token)    free(batch.token);
+    if (batch.embd)     free(batch.embd);
+    if (batch.pos)      free(batch.pos);
+    if (batch.n_seq_id) free(batch.n_seq_id);
+    if (batch.seq_id) {
+        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
+            free(batch.seq_id[i]);
+        }
+        free(batch.seq_id);
+    }
+    if (batch.logits)   free(batch.logits);
+}
diff --git a/src/llama-batch.h b/src/llama-batch.h
index 69e379a2e..773c3808b 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -86,4 +86,3 @@ struct llama_batch_allocr {
     // optionally fulfill the batch returned by llama_batch_get_one
     llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
 };
-
diff --git a/src/llama-impl.cpp b/src/llama-impl.cpp
index fc3fee213..f8659796f 100644
--- a/src/llama-impl.cpp
+++ b/src/llama-impl.cpp
@@ -2,7 +2,9 @@
 
 #include "llama.h"
 
+#include <climits>
 #include <cstdarg>
+#include <vector>
 
 struct llama_logger_state {
     ggml_log_callback log_callback = llama_log_callback_default;
@@ -19,23 +21,6 @@ time_meas::~time_meas() {
     }
 }
 
-void replace_all(std::string & s, const std::string & search, const std::string & replace) {
-    if (search.empty()) {
-        return;
-    }
-    std::string builder;
-    builder.reserve(s.length());
-    size_t pos = 0;
-    size_t last_pos = 0;
-    while ((pos = s.find(search, last_pos)) != std::string::npos) {
-        builder.append(s, last_pos, pos - last_pos);
-        builder.append(replace);
-        last_pos = pos + search.length();
-    }
-    builder.append(s, last_pos, std::string::npos);
-    s = std::move(builder);
-}
-
 void llama_log_set(ggml_log_callback log_callback, void * user_data) {
     ggml_log_set(log_callback, user_data);
     g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
@@ -72,3 +57,35 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
     fputs(text, stderr);
     fflush(stderr);
 }
+
+void replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    if (search.empty()) {
+        return;
+    }
+    std::string builder;
+    builder.reserve(s.length());
+    size_t pos = 0;
+    size_t last_pos = 0;
+    while ((pos = s.find(search, last_pos)) != std::string::npos) {
+        builder.append(s, last_pos, pos - last_pos);
+        builder.append(replace);
+        last_pos = pos + search.length();
+    }
+    builder.append(s, last_pos, std::string::npos);
+    s = std::move(builder);
+}
+
+std::string format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), size);
+}
diff --git a/src/llama-impl.h b/src/llama-impl.h
index dbe5c21c5..fbf88039d 100644
--- a/src/llama-impl.h
+++ b/src/llama-impl.h
@@ -1,6 +1,6 @@
 #pragma once
 
-#include "ggml.h"
+#include "ggml.h" // for ggml_log_level
 
 #include <string>
 
@@ -22,10 +22,6 @@ LLAMA_ATTRIBUTE_FORMAT(2, 3)
 void llama_log_internal        (ggml_log_level level, const char * format, ...);
 void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
 
-// TODO: rename to llama_format ?
-LLAMA_ATTRIBUTE_FORMAT(1, 2)
-std::string format(const char * fmt, ...);
-
 #define LLAMA_LOG(...)      llama_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__)
 #define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
 #define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
@@ -47,3 +43,7 @@ struct time_meas {
 };
 
 void replace_all(std::string & s, const std::string & search, const std::string & replace);
+
+// TODO: rename to llama_format ?
+LLAMA_ATTRIBUTE_FORMAT(1, 2)
+std::string format(const char * fmt, ...);
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index f4a72bebc..dea982cc2 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -189,3 +189,225 @@ struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, co
 
     return it->second;
 }
+
+size_t llama_model_max_nodes(const llama_model & model) {
+    return std::max<size_t>(8192, model.tensors_by_name.size()*5);
+}
+
+//
+// interface implementation
+//
+
+struct llama_model_params llama_model_default_params() {
+    struct llama_model_params result = {
+        /*.devices                     =*/ nullptr,
+        /*.n_gpu_layers                =*/ 0,
+        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
+        /*.main_gpu                    =*/ 0,
+        /*.tensor_split                =*/ nullptr,
+        /*.rpc_servers                 =*/ nullptr,
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
+        /*.kv_overrides                =*/ nullptr,
+        /*.vocab_only                  =*/ false,
+        /*.use_mmap                    =*/ true,
+        /*.use_mlock                   =*/ false,
+        /*.check_tensors               =*/ false,
+    };
+
+#ifdef GGML_USE_METAL
+    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
+    result.n_gpu_layers = 999;
+#endif
+
+    return result;
+}
+
+void llama_free_model(struct llama_model * model) {
+    delete model;
+}
+
+enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
+    return model->vocab.type;
+}
+
+int32_t llama_n_vocab(const struct llama_model * model) {
+    return model->hparams.n_vocab;
+}
+
+int32_t llama_n_ctx_train(const struct llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
+int32_t llama_n_embd(const struct llama_model * model) {
+    return model->hparams.n_embd;
+}
+
+int32_t llama_n_layer(const struct llama_model * model) {
+    return model->hparams.n_layer;
+}
+
+int32_t llama_n_head(const struct llama_model * model) {
+    return model->hparams.n_head();
+}
+
+enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+    switch (model->arch) {
+        // these models do not use RoPE
+        case LLM_ARCH_GPT2:
+        case LLM_ARCH_GPTJ:
+        case LLM_ARCH_MPT:
+        case LLM_ARCH_REFACT:
+        case LLM_ARCH_BLOOM:
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_T5:
+        case LLM_ARCH_T5ENCODER:
+        case LLM_ARCH_JAIS:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_WAVTOKENIZER_DEC:
+            return LLAMA_ROPE_TYPE_NONE;
+
+        // use what we call a normal RoPE, operating on pairs of consecutive head values
+        case LLM_ARCH_LLAMA:
+        case LLM_ARCH_DECI:
+        case LLM_ARCH_BAICHUAN:
+        case LLM_ARCH_STARCODER:
+        case LLM_ARCH_PLAMO:
+        case LLM_ARCH_ORION:
+        case LLM_ARCH_INTERNLM2:
+        case LLM_ARCH_MINICPM:
+        case LLM_ARCH_XVERSE:
+        case LLM_ARCH_COMMAND_R:
+        case LLM_ARCH_OLMO:
+        case LLM_ARCH_ARCTIC:
+        case LLM_ARCH_DEEPSEEK:
+        case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_CHATGLM:
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_CHAMELEON:
+            return LLAMA_ROPE_TYPE_NORM;
+
+        // the pairs of head values are offset by n_rot/2
+        case LLM_ARCH_FALCON:
+        case LLM_ARCH_GROK:
+        case LLM_ARCH_DBRX:
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_STABLELM:
+        case LLM_ARCH_BITNET:
+        case LLM_ARCH_QWEN:
+        case LLM_ARCH_QWEN2:
+        case LLM_ARCH_QWEN2MOE:
+        case LLM_ARCH_OLMO2:
+        case LLM_ARCH_OLMOE:
+        case LLM_ARCH_PHI2:
+        case LLM_ARCH_PHI3:
+        case LLM_ARCH_GEMMA:
+        case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_STARCODER2:
+        case LLM_ARCH_OPENELM:
+        case LLM_ARCH_GPTNEOX:
+        case LLM_ARCH_CODESHELL:
+        case LLM_ARCH_NEMOTRON:
+        case LLM_ARCH_EXAONE:
+        case LLM_ARCH_MINICPM3:
+            return LLAMA_ROPE_TYPE_NEOX;
+
+        case LLM_ARCH_QWEN2VL:
+            return LLAMA_ROPE_TYPE_MROPE;
+
+        // all model arches should be listed explicitly here
+        case LLM_ARCH_UNKNOWN:
+            GGML_ABORT("unknown architecture");
+    }
+
+    return LLAMA_ROPE_TYPE_NONE;
+}
+
+float llama_rope_freq_scale_train(const struct llama_model * model) {
+    return model->hparams.rope_freq_scale_train;
+}
+
+int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) {
+    const auto & it = model->gguf_kv.find(key);
+    if (it == model->gguf_kv.end()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_model_meta_count(const struct llama_model * model) {
+    return (int)model->gguf_kv.size();
+}
+
+int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)model->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = model->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->first.c_str());
+}
+
+int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)model->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = model->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
+    return snprintf(buf, buf_size, "%s %s %s",
+            llama_model_arch_name (*model).c_str(),
+            llama_model_type_name (*model).c_str(),
+            llama_model_ftype_name(*model).c_str());
+}
+
+uint64_t llama_model_size(const struct llama_model * model) {
+    return model->n_bytes;
+}
+
+uint64_t llama_model_n_params(const struct llama_model * model) {
+    return model->n_elements;
+}
+
+bool llama_model_has_encoder(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5:        return true;
+        case LLM_ARCH_T5ENCODER: return true;
+        default:                 return false;
+    }
+}
+
+bool llama_model_has_decoder(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5ENCODER: return false;
+        default:                 return true;
+    }
+}
+
+llama_token llama_model_decoder_start_token(const struct llama_model * model) {
+    return model->hparams.dec_start_token_id;
+}
+
+bool llama_model_is_recurrent(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_MAMBA:  return true;
+        case LLM_ARCH_RWKV6:  return true;
+        default:              return false;
+    }
+}
+
diff --git a/src/llama-model.h b/src/llama-model.h
index 5123ac9a0..792f7cdca 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -377,3 +377,4 @@ ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, in
 
 // used by llama_adapter_lora
 struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name);
+size_t llama_model_max_nodes(const llama_model & model);
diff --git a/src/llama.cpp b/src/llama.cpp
index e706d9343..ba3c0c74a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7,7 +7,7 @@
 #include "llama-sampling.h"
 #include "llama-kv-cache.h"
 
-#include "unicode.h"
+#include "unicode.h" // TODO: remove
 
 #include "ggml.h"
 #include "ggml-alloc.h"
@@ -40,7 +40,6 @@
 #include <cinttypes>
 #include <climits>
 #include <cmath>
-#include <cstdarg>
 #include <cstddef>
 #include <cstdint>
 #include <cstdio>
@@ -54,41 +53,6 @@
 //
 // helpers
 //
 
-std::string format(const char * fmt, ...) {
-    va_list ap;
-    va_list ap2;
-    va_start(ap, fmt);
-    va_copy(ap2, ap);
-    int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
-    std::vector<char> buf(size + 1);
-    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
-    GGML_ASSERT(size2 == size);
-    va_end(ap2);
-    va_end(ap);
-    return std::string(buf.data(), size);
-}
-
-static bool is_float_close(float a, float b, float abs_tol) {
-    // Check for non-negative tolerance
-    if (abs_tol < 0.0) {
-        throw std::invalid_argument("Tolerance must be non-negative");
-    }
-
-    // Exact equality check
-    if (a == b) {
-        return true;
-    }
-
-    // Check for infinities
-    if (std::isinf(a) || std::isinf(b)) {
-        return false;
-    }
-
-    // Regular comparison using the provided absolute tolerance
-    return std::fabs(b - a) <= abs_tol;
-}
-
 static void zeros(std::ofstream & file, size_t n) {
     char zero = 0;
     for (size_t i = 0; i < n; ++i) {
@@ -416,10 +380,6 @@ namespace GGUFMeta {
 
 using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
 
-static size_t llama_model_max_nodes(const llama_model & model) {
-    return std::max<size_t>(8192, model.tensors_by_name.size()*5);
-}
-
 struct llama_model_loader {
     int n_kv      = 0;
     int n_tensors = 0;
@@ -2468,8 +2428,6 @@ static void llm_load_vocab(
 
     for (uint32_t i = 0; i < n_vocab; i++) {
         std::string word = gguf_get_arr_str(ctx, token_idx, i);
-
-        //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
         if (word.empty()) {
             LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i);
             word = "[EMPTY_" + std::to_string(i) + "]";
@@ -15674,31 +15632,6 @@ int32_t llama_control_vector_apply(
 // interface implementation
 //
 
-struct llama_model_params llama_model_default_params() {
-    struct llama_model_params result = {
-        /*.devices                     =*/ nullptr,
-        /*.n_gpu_layers                =*/ 0,
-        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
-        /*.main_gpu                    =*/ 0,
-        /*.tensor_split                =*/ nullptr,
-        /*.rpc_servers                 =*/ nullptr,
-        /*.progress_callback           =*/ nullptr,
-        /*.progress_callback_user_data =*/ nullptr,
-        /*.kv_overrides                =*/ nullptr,
-        /*.vocab_only                  =*/ false,
-        /*.use_mmap                    =*/ true,
-        /*.use_mlock                   =*/ false,
-        /*.check_tensors               =*/ false,
-    };
-
-#ifdef GGML_USE_METAL
-    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
-    result.n_gpu_layers = 999;
-#endif
-
-    return result;
-}
-
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
         /*.n_ctx                       =*/ 512,
@@ -15825,7 +15758,7 @@ int64_t llama_time_us(void) {
 
 struct llama_model * llama_load_model_from_file(
         const char * path_model,
-        struct llama_model_params   params) {
+        struct llama_model_params params) {
     ggml_time_init();
 
     llama_model * model = new llama_model;
@@ -15943,10 +15876,6 @@ struct llama_model * llama_load_model_from_file(
 
     return model;
 }
 
-void llama_free_model(struct llama_model * model) {
-    delete model;
-}
-
 struct llama_context * llama_new_context_with_model(
                  struct llama_model * model,
         struct llama_context_params   params) {
@@ -16318,30 +16247,6 @@ uint32_t llama_n_seq_max(const struct llama_context * ctx) {
     return ctx->kv_self.size;
 }
 
-enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
-    return model->vocab.type;
-}
-
-int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->hparams.n_vocab;
-}
-
-int32_t llama_n_ctx_train(const struct llama_model * model) {
-    return model->hparams.n_ctx_train;
-}
-
-int32_t llama_n_embd(const struct llama_model * model) {
-    return model->hparams.n_embd;
-}
-
-int32_t llama_n_layer(const struct llama_model * model) {
-    return model->hparams.n_layer;
-}
-
-int32_t llama_n_head(const struct llama_model * model) {
-    return model->hparams.n_head();
-}
-
 const struct llama_model * llama_get_model(const struct llama_context * ctx) {
     return &ctx->model;
 }
@@ -16350,166 +16255,6 @@ enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
     return ctx->cparams.pooling_type;
 }
 
-enum llama_rope_type llama_rope_type(const struct llama_model * model) {
-    switch (model->arch) {
-        // these models do not use RoPE
-        case LLM_ARCH_GPT2:
-        case LLM_ARCH_GPTJ:
-        case LLM_ARCH_MPT:
-        case LLM_ARCH_REFACT:
-        case LLM_ARCH_BLOOM:
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_JINA_BERT_V2:
-        case LLM_ARCH_T5:
-        case LLM_ARCH_T5ENCODER:
-        case LLM_ARCH_JAIS:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_WAVTOKENIZER_DEC:
-            return LLAMA_ROPE_TYPE_NONE;
-
-        // use what we call a normal RoPE, operating on pairs of consecutive head values
-        case LLM_ARCH_LLAMA:
-        case LLM_ARCH_DECI:
-        case LLM_ARCH_BAICHUAN:
-        case LLM_ARCH_STARCODER:
-        case LLM_ARCH_PLAMO:
-        case LLM_ARCH_ORION:
-        case LLM_ARCH_INTERNLM2:
-        case LLM_ARCH_MINICPM:
-        case LLM_ARCH_XVERSE:
-        case LLM_ARCH_COMMAND_R:
-        case LLM_ARCH_OLMO:
-        case LLM_ARCH_ARCTIC:
-        case LLM_ARCH_DEEPSEEK:
-        case LLM_ARCH_DEEPSEEK2:
-        case LLM_ARCH_CHATGLM:
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
-        case LLM_ARCH_CHAMELEON:
-            return LLAMA_ROPE_TYPE_NORM;
-
-        // the pairs of head values are offset by n_rot/2
-        case LLM_ARCH_FALCON:
-        case LLM_ARCH_GROK:
-        case LLM_ARCH_DBRX:
-        case LLM_ARCH_BERT:
-        case LLM_ARCH_NOMIC_BERT:
-        case LLM_ARCH_STABLELM:
-        case LLM_ARCH_BITNET:
-        case LLM_ARCH_QWEN:
-        case LLM_ARCH_QWEN2:
-        case LLM_ARCH_QWEN2MOE:
-        case LLM_ARCH_OLMO2:
-        case LLM_ARCH_OLMOE:
-        case LLM_ARCH_PHI2:
-        case LLM_ARCH_PHI3:
-        case LLM_ARCH_GEMMA:
-        case LLM_ARCH_GEMMA2:
-        case LLM_ARCH_STARCODER2:
-        case LLM_ARCH_OPENELM:
-        case LLM_ARCH_GPTNEOX:
-        case LLM_ARCH_CODESHELL:
-        case LLM_ARCH_NEMOTRON:
-        case LLM_ARCH_EXAONE:
-        case LLM_ARCH_MINICPM3:
-            return LLAMA_ROPE_TYPE_NEOX;
-
-        case LLM_ARCH_QWEN2VL:
-            return LLAMA_ROPE_TYPE_MROPE;
-
-        // all model arches should be listed explicitly here
-        case LLM_ARCH_UNKNOWN:
-            GGML_ABORT("unknown architecture");
-    }
-
-    return LLAMA_ROPE_TYPE_NONE;
-}
-
-float llama_rope_freq_scale_train(const struct llama_model * model) {
-    return model->hparams.rope_freq_scale_train;
-}
-
-int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) {
-    const auto & it = model->gguf_kv.find(key);
-    if (it == model->gguf_kv.end()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
-
-int32_t llama_model_meta_count(const struct llama_model * model) {
-    return (int)model->gguf_kv.size();
-}
-
-int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->first.c_str());
-}
-
-int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
-
-int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
-    return snprintf(buf, buf_size, "%s %s %s",
-            llama_model_arch_name (*model).c_str(),
-            llama_model_type_name (*model).c_str(),
-            llama_model_ftype_name(*model).c_str());
-}
-
-uint64_t llama_model_size(const struct llama_model * model) {
-    return model->n_bytes;
-}
-
-uint64_t llama_model_n_params(const struct llama_model * model) {
-    return model->n_elements;
-}
-
-bool llama_model_has_encoder(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_T5:        return true;
-        case LLM_ARCH_T5ENCODER: return true;
-        default:                 return false;
-    }
-}
-
-bool llama_model_has_decoder(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_T5ENCODER: return false;
-        default:                 return true;
-    }
-}
-
-llama_token llama_model_decoder_start_token(const struct llama_model * model) {
-    return model->hparams.dec_start_token_id;
-}
-
-bool llama_model_is_recurrent(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_MAMBA:  return true;
-        case LLM_ARCH_RWKV6:  return true;
-        default:              return false;
-    }
-}
-
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
@@ -16523,7 +16268,11 @@ uint32_t llama_model_quantize(
     }
 }
 
-///
+//
+// kv cache
+//
+
+// TODO: tmp bridges below until `struct llama_kv_cache` is exposed through the public API
 
 struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
     return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
 }
@@ -16628,64 +16377,6 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
     ctx->cparams.causal_attn = causal_attn;
 }
 
-struct llama_batch llama_batch_get_one(
-             llama_token * tokens,
-                 int32_t   n_tokens) {
-    return {
-        /*n_tokens =*/ n_tokens,
-        /*tokens   =*/ tokens,
-        /*embd     =*/ nullptr,
-        /*pos      =*/ nullptr,
-        /*n_seq_id =*/ nullptr,
-        /*seq_id   =*/ nullptr,
-        /*logits   =*/ nullptr,
-    };
-}
-
-struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
-    llama_batch batch = {
-        /*n_tokens =*/ 0,
-        /*tokens   =*/ nullptr,
-        /*embd     =*/ nullptr,
-        /*pos      =*/ nullptr,
-        /*n_seq_id =*/ nullptr,
-        /*seq_id   =*/ nullptr,
-        /*logits   =*/ nullptr,
-    };
-
-    if (embd) {
-        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
-    } else {
-        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
-    }
-
-    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
-    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
-    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
-    for (int i = 0; i < n_tokens_alloc; ++i) {
-        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
-    }
-    batch.seq_id[n_tokens_alloc] = nullptr;
-
-    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);
-
-    return batch;
-}
-
-void llama_batch_free(struct llama_batch batch) {
-    if (batch.token)    free(batch.token);
-    if (batch.embd)     free(batch.embd);
-    if (batch.pos)      free(batch.pos);
-    if (batch.n_seq_id) free(batch.n_seq_id);
-    if (batch.seq_id) {
-        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
-            free(batch.seq_id[i]);
-        }
-        free(batch.seq_id);
-    }
-    if (batch.logits)   free(batch.logits);
-}
-
 int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
@@ -16852,6 +16543,8 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
 // vocab
 //
 
+// TODO: tmp bridges below until `struct llama_vocab` is exposed through the public API
+
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
     return llama_token_get_text_impl(model->vocab, token);
 }