diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 038cf58dd..466e7bc61 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -7,14 +7,12 @@ #include #include #include +#include #include #include #include #include -// TODO: replace with ggml API call -#define QK_K 256 - static void zeros(std::ofstream & file, size_t n) { char zero = 0; for (size_t i = 0; i < n; ++i) { @@ -154,8 +152,10 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t if (qs.params->output_tensor_type < GGML_TYPE_COUNT) { new_type = qs.params->output_tensor_type; } else { - int nx = tensor->ne[0]; - if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) { + const int64_t nx = tensor->ne[0]; + const int64_t qk_k = ggml_blck_size(new_type); + + if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) { new_type = GGML_TYPE_Q8_0; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || @@ -367,20 +367,19 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; //} bool convert_incompatible_tensor = false; - if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || - new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS || - new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || - new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || - new_type == GGML_TYPE_IQ1_M) { - int nx = tensor->ne[0]; - int ny = tensor->ne[1]; - if (nx % QK_K != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type)); + { + const int64_t nx = tensor->ne[0]; + const int64_t ny = tensor->ne[1]; + const int64_t qk_k = ggml_blck_size(new_type); + + if (nx % qk_k != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); convert_incompatible_tensor = true; } else { ++qs.n_k_quantized; } } + if (convert_incompatible_tensor) { switch (new_type) { case GGML_TYPE_TQ1_0: