gguf : simplify gguf_get_val

This commit is contained in:
Georgi Gerganov 2023-07-26 18:53:57 +03:00
parent cb871fa022
commit d313c0fa33
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 25 additions and 63 deletions

68
ggml.c
View File

@ -18297,6 +18297,19 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
////////////////////////////////////////////////////////////////////////////////
enum gguf_type {
GGUF_TYPE_UINT8 = 0,
GGUF_TYPE_INT8 = 1,
GGUF_TYPE_UINT16 = 2,
GGUF_TYPE_INT16 = 3,
GGUF_TYPE_UINT32 = 4,
GGUF_TYPE_INT32 = 5,
GGUF_TYPE_FLOAT32 = 6,
GGUF_TYPE_BOOL = 7,
GGUF_TYPE_STRING = 8,
GGUF_TYPE_ARRAY = 9,
};
struct gguf_str {
uint32_t n;
char * data;
@ -18670,77 +18683,40 @@ enum gguf_type gguf_get_type(struct gguf_context * ctx, int i) {
return ctx->header.kv[i].type;
}
void gguf_get_val(struct gguf_context * ctx, int i, void * val) {
struct gguf_kv * kv = &ctx->header.kv[i];
switch (kv->type) {
case GGUF_TYPE_UINT8: memcpy(val, &kv->value.uint8, sizeof(uint8_t)); break;
case GGUF_TYPE_INT8: memcpy(val, &kv->value.int8, sizeof(int8_t)); break;
case GGUF_TYPE_UINT16: memcpy(val, &kv->value.uint16, sizeof(uint16_t)); break;
case GGUF_TYPE_INT16: memcpy(val, &kv->value.int16, sizeof(int16_t)); break;
case GGUF_TYPE_UINT32: memcpy(val, &kv->value.uint32, sizeof(uint32_t)); break;
case GGUF_TYPE_INT32: memcpy(val, &kv->value.int32, sizeof(int32_t)); break;
case GGUF_TYPE_FLOAT32: memcpy(val, &kv->value.float32, sizeof(float)); break;
case GGUF_TYPE_BOOL: memcpy(val, &kv->value.bool_, sizeof(bool)); break;
case GGUF_TYPE_STRING: memcpy(val, &kv->value.str.data, sizeof(char *)); break;
default:
GGML_ASSERT("gguf: not implemented");
break;
}
}
uint8_t gguf_get_val_u8(struct gguf_context * ctx, int i) {
uint8_t val;
gguf_get_val(ctx, i, &val);
return val;
return ctx->header.kv[i].value.uint8;
}
int8_t gguf_get_val_i8(struct gguf_context * ctx, int i) {
int8_t val;
gguf_get_val(ctx, i, &val);
return val;
return ctx->header.kv[i].value.int8;
}
uint16_t gguf_get_val_u16(struct gguf_context * ctx, int i) {
uint16_t val;
gguf_get_val(ctx, i, &val);
return val;
return ctx->header.kv[i].value.uint16;
}
int16_t gguf_get_val_i16(struct gguf_context * ctx, int i) {
int16_t val;
gguf_get_val(ctx, i, &val);
return val;
return ctx->header.kv[i].value.int16;
}
uint32_t gguf_get_val_u32(struct gguf_context * ctx, int i) {
uint32_t val;
gguf_get_val(ctx, i, &val);
return val;
return ctx->header.kv[i].value.uint32;
}
int32_t gguf_get_val_i32(struct gguf_context * ctx, int i) {
int32_t val;
gguf_get_val(ctx, i, &val);
return val;
return ctx->header.kv[i].value.int32;
}
float gguf_get_val_f32(struct gguf_context * ctx, int i) {
float val;
gguf_get_val(ctx, i, &val);
return val;
return ctx->header.kv[i].value.float32;
}
bool gguf_get_val_bool(struct gguf_context * ctx, int i) {
bool val;
gguf_get_val(ctx, i, &val);
return val;
return ctx->header.kv[i].value.bool_;
}
const char * gguf_get_val_str (struct gguf_context * ctx, int i) {
char * val;
gguf_get_val(ctx, i, &val);
return val;
return ctx->header.kv[i].value.str.data;
}
int gguf_get_n_tensors(struct gguf_context * ctx) {

20
ggml.h
View File

@ -1619,19 +1619,6 @@ extern "C" {
// gguf
//
enum gguf_type {
GGUF_TYPE_UINT8 = 0,
GGUF_TYPE_INT8 = 1,
GGUF_TYPE_UINT16 = 2,
GGUF_TYPE_INT16 = 3,
GGUF_TYPE_UINT32 = 4,
GGUF_TYPE_INT32 = 5,
GGUF_TYPE_FLOAT32 = 6,
GGUF_TYPE_BOOL = 7,
GGUF_TYPE_STRING = 8,
GGUF_TYPE_ARRAY = 9,
};
struct gguf_context;
struct gguf_init_params {
@ -1651,10 +1638,9 @@ extern "C" {
GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx);
GGML_API void * gguf_get_data (struct gguf_context * ctx);
GGML_API int gguf_get_n_kv(struct gguf_context * ctx);
GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);
GGML_API enum gguf_type gguf_get_type(struct gguf_context * ctx, int i);
GGML_API void gguf_get_val (struct gguf_context * ctx, int i, void * val);
GGML_API int gguf_get_n_kv(struct gguf_context * ctx);
GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);
GGML_API void gguf_get_val (struct gguf_context * ctx, int i, void * val);
GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i);
GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i);