gguf : add gguf_get_kv_type

This commit is contained in:
M. Yusuf Sarıgöz 2023-08-11 13:03:23 +03:00
parent eb8ca6996f
commit e3a4960953
3 changed files with 25 additions and 7 deletions

2
ggml.c
View File

@ -19031,7 +19031,7 @@ const char * gguf_get_key(struct gguf_context * ctx, int i) {
return ctx->header.kv[i].key.data; return ctx->header.kv[i].key.data;
} }
const enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) { enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) {
return ctx->header.kv[i].type; return ctx->header.kv[i].type;
} }

10
ggml.h
View File

@ -1744,11 +1744,11 @@ extern "C" {
GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx); GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx);
GGML_API void * gguf_get_data (struct gguf_context * ctx); GGML_API void * gguf_get_data (struct gguf_context * ctx);
GGML_API int gguf_get_n_kv(struct gguf_context * ctx); GGML_API int gguf_get_n_kv(struct gguf_context * ctx);
GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key); GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key);
GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i); GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);
GGML_API const enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i); GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
GGML_API void gguf_get_val (struct gguf_context * ctx, int i, void * val); GGML_API void gguf_get_val (struct gguf_context * ctx, int i, void * val);
GGML_API const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i); GGML_API const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i);
GGML_API float gguf_get_arr_f32(struct gguf_context * ctx, int key_id, int i); GGML_API float gguf_get_arr_f32(struct gguf_context * ctx, int key_id, int i);

View File

@ -536,6 +536,7 @@ struct ggml_context * ctx_data = NULL;
hparams.n_ctx = read_u32("llama.context_length"); hparams.n_ctx = read_u32("llama.context_length");
hparams.n_embd = read_u32("llama.embedding_length"); hparams.n_embd = read_u32("llama.embedding_length");
uint32_t n_ff = read_u32("llama.feed_forward_length"); uint32_t n_ff = read_u32("llama.feed_forward_length");
GGML_UNUSED(n_ff);
//hparams.n_mult = find_n_mult(n_ff, hparams.n_embd); //hparams.n_mult = find_n_mult(n_ff, hparams.n_embd);
hparams.n_head = read_u32("llama.attention.head_count"); hparams.n_head = read_u32("llama.attention.head_count");
hparams.n_layer = read_u32("llama.layer_count"); hparams.n_layer = read_u32("llama.layer_count");
@ -654,7 +655,21 @@ struct gguf_file_saver {
file.write_val<uint32_t>("general.quantization_version", GGUF_TYPE_UINT32, new_ftype); file.write_val<uint32_t>("general.quantization_version", GGUF_TYPE_UINT32, new_ftype);
} else { } else {
const gguf_type vtype = gguf_get_kv_type(any_file_loader->gguf_ctx, i); const gguf_type vtype = gguf_get_kv_type(any_file_loader->gguf_ctx, i);
GGML_UNUSED(vtype); switch(vtype) {
case GGUF_TYPE_BOOL:
case GGUF_TYPE_FLOAT32:
case GGUF_TYPE_INT16:
case GGUF_TYPE_INT32:
case GGUF_TYPE_INT8:
case GGUF_TYPE_STRING:
case GGUF_TYPE_UINT16:
case GGUF_TYPE_UINT32:
case GGUF_TYPE_UINT8:
case GGUF_TYPE_ARRAY:
break;
default:
throw std::runtime_error(format("cannot recognize value type for key %s\n", key));
}
} }
} }
@ -3873,6 +3888,9 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
gguf_file file(path_session, "wb"); gguf_file file(path_session, "wb");
GGML_UNUSED(ctx);
GGML_UNUSED(tokens);
GGML_UNUSED(n_token_count);
// TODO: implement with GGUF format // TODO: implement with GGUF format