llama : fix tensor name grepping during quantization

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-08-17 21:40:51 +03:00
parent 57eaadb853
commit 5484737d58
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -3432,6 +3432,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
const std::string name = ggml_get_name(meta);
// TODO: avoid hardcoded tensor names - use the TN_* constants
if (name.find("attn_v.weight") != std::string::npos) {
++n_attention_wv;
}
@ -3510,6 +3511,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
} else {
new_type = quantized_type;
#ifdef GGML_USE_K_QUANTS
// TODO: avoid hardcoded tensor names - use the TN_* constants
if (name == TN_OUTPUT) {
int nx = tensor->ne[0];
int ny = tensor->ne[1];
@ -3524,7 +3526,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
(i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
++i_attention_wv;
} else if (name.find("feed_forward.w2.weight") != std::string::npos) {
} else if (name.find("ffn_down.weight") != std::string::npos) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&