From 82e4f64578dc3db40185e6b91195e73c9e995952 Mon Sep 17 00:00:00 2001 From: Radek Pilar Date: Tue, 12 Dec 2023 20:04:10 +0100 Subject: [PATCH] convert-hf : support for mixtral-instruct (#4428) * convert : typo fix, add additional hyperparameters, use LLaMA arch for Mixtral-instruct * convert : use sentencepiece tokenizer for Mixtral-instruct * convert : make flake8 happy --- convert-hf-to-gguf.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index bced1f561..e46a7813a 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -77,8 +77,18 @@ class Model: self.gguf_writer.add_embedding_length(n_embd) if (n_ff := self.hparams.get("intermediate_size")) is not None: self.gguf_writer.add_feed_forward_length(n_ff) - if (n_head := self.hparams.get("num_attention_head")) is not None: + if (n_head := self.hparams.get("num_attention_heads")) is not None: self.gguf_writer.add_head_count(n_head) + if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: + self.gguf_writer.add_head_count_kv(n_head_kv) + + if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None: + self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps) + if (n_experts := self.hparams.get("num_local_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) + self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True)) def write_tensors(self): @@ -170,6 +180,8 @@ class Model: return StableLMModel if model_architecture == "QWenLMHeadModel": return QwenModel + if model_architecture == "MixtralForCausalLM": + return MixtralModel return Model def _is_model_safetensors(self) -> bool: @@ -207,6 +219,8 @@ class Model: return gguf.MODEL_ARCH.STABLELM if arch == "QWenLMHeadModel": return gguf.MODEL_ARCH.QWEN + if arch == "MixtralForCausalLM": + return gguf.MODEL_ARCH.LLAMA raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -837,6 +851,11 @@ class StableLMModel(Model): self.gguf_writer.add_layer_norm_eps(1e-5) +class MixtralModel(Model): + def set_vocab(self): + self._set_vocab_sentencepiece() + + class QwenModel(Model): @staticmethod def token_bytes_to_string(b):