From e3a49609533646cbd1990f2d99668aef7a0cc420 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Yusuf=20Sar=C4=B1g=C3=B6z?= Date: Fri, 11 Aug 2023 13:03:23 +0300 Subject: [PATCH] gguf : add gguf_get_kv_type --- ggml.c | 2 +- ggml.h | 10 +++++----- gguf-llama.cpp | 20 +++++++++++++++++++- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/ggml.c b/ggml.c index ac45f12b0..e00f09fa4 100644 --- a/ggml.c +++ b/ggml.c @@ -19031,7 +19031,7 @@ const char * gguf_get_key(struct gguf_context * ctx, int i) { return ctx->header.kv[i].key.data; } -const enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) { +enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) { return ctx->header.kv[i].type; } diff --git a/ggml.h b/ggml.h index 4490e076c..9a266e175 100644 --- a/ggml.h +++ b/ggml.h @@ -1744,11 +1744,11 @@ extern "C" { GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx); GGML_API void * gguf_get_data (struct gguf_context * ctx); - GGML_API int gguf_get_n_kv(struct gguf_context * ctx); - GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key); - GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i); - GGML_API const enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i); - GGML_API void gguf_get_val (struct gguf_context * ctx, int i, void * val); + GGML_API int gguf_get_n_kv(struct gguf_context * ctx); + GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key); + GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i); + GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i); + GGML_API void gguf_get_val (struct gguf_context * ctx, int i, void * val); GGML_API const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i); GGML_API float gguf_get_arr_f32(struct gguf_context * ctx, int key_id, int i); diff --git a/gguf-llama.cpp b/gguf-llama.cpp index 8b928c364..9ab770277 100644 --- a/gguf-llama.cpp +++ b/gguf-llama.cpp @@ -536,6 +536,7 @@ struct ggml_context * ctx_data = NULL; hparams.n_ctx = read_u32("llama.context_length"); hparams.n_embd = read_u32("llama.embedding_length"); uint32_t n_ff = read_u32("llama.feed_forward_length"); + GGML_UNUSED(n_ff); //hparams.n_mult = find_n_mult(n_ff, hparams.n_embd); hparams.n_head = read_u32("llama.attention.head_count"); hparams.n_layer = read_u32("llama.layer_count"); @@ -654,7 +655,21 @@ struct gguf_file_saver { file.write_val("general.quantization_version", GGUF_TYPE_UINT32, new_ftype); } else { const gguf_type vtype = gguf_get_kv_type(any_file_loader->gguf_ctx, i); - GGML_UNUSED(vtype); + switch(vtype) { + case GGUF_TYPE_BOOL: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_INT16: + case GGUF_TYPE_INT32: + case GGUF_TYPE_INT8: + case GGUF_TYPE_STRING: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_UINT8: + case GGUF_TYPE_ARRAY: + break; + default: + throw std::runtime_error(format("cannot recognize value type for key %s\n", key)); + } } } @@ -3873,6 +3888,9 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { gguf_file file(path_session, "wb"); + GGML_UNUSED(ctx); + GGML_UNUSED(tokens); + GGML_UNUSED(n_token_count); // TODO: implement with GGUF format