convert_hf : fix Jamba conversion

Francis Couture-Harpin 2024-09-01 21:46:27 -04:00
parent a03e32a3c9
commit 9d3f44dad4

@@ -2910,7 +2910,6 @@ class JambaModel(Model):
             n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
         ]
 
-        self.gguf_writer.add_name(self.dir_model.name)
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
         self.gguf_writer.add_embedding_length(d_model)
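
The explicit `gguf_writer.add_name(...)` call is dropped here, presumably because the base `Model` class now writes `general.name` itself as part of its metadata handling, so per-model converters no longer set it. As a quick sanity-check sketch (not part of the commit), the metadata written by `set_gguf_parameters()` can be listed with gguf-py's reader after conversion; "jamba.gguf" is a hypothetical output path and the "jamba." key prefix is an assumption:

# list the metadata keys of the converted file with gguf-py
from gguf import GGUFReader

reader = GGUFReader("jamba.gguf")  # hypothetical output path
for key in reader.fields:
    if key.startswith(("general.", "jamba.")):
        print(key)
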
@@ -2979,8 +2978,8 @@ class JambaModel(Model):
 
         yield new_name, data_torch
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._experts is not None:
             # flatten `list[dict[str, Tensor]]` into `list[str]`
@@ -2988,20 +2987,6 @@ class JambaModel(Model):
             if len(experts) > 0:
                 raise ValueError(f"Unprocessed experts: {experts}")
 
-    # same as Mamba
-    def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del n_dims  # unused
-
-        return bid is not None and new_name in (
-            self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
-                gguf.MODEL_TENSOR.SSM_CONV1D,
-                gguf.MODEL_TENSOR.SSM_X,
-                gguf.MODEL_TENSOR.SSM_DT,
-                gguf.MODEL_TENSOR.SSM_A,
-                gguf.MODEL_TENSOR.SSM_D,
-            ]
-        )
-
 
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
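
The removed `extra_f32_tensors` override forced the per-layer Mamba/SSM tensors (conv1d, x, dt, A, D) to stay F32 in F16 or quantized conversions; dropping it suggests the current base class covers these tensors on its own, so the Jamba-specific override is no longer needed. For reference, a standalone restatement of what the removed check did; `format_tensor_name` below is a hypothetical stand-in for `Model.format_tensor_name`, and the GGUF tensor names are assumptions:

SSM_TENSORS = ["ssm_conv1d", "ssm_x", "ssm_dt", "ssm_a", "ssm_d"]

def format_tensor_name(base: str, bid: int, suffix: str) -> str:
    # hypothetical stand-in: builds the per-layer GGUF tensor name
    return f"blk.{bid}.{base}{suffix}"

def keep_as_f32(name: str, new_name: str, bid: int | None) -> bool:
    # True only for the per-layer SSM tensors, so they would be written as F32
    # even when the rest of the model is converted to F16 or quantized
    if bid is None:
        return False
    suffix = ".weight" if name.endswith(".weight") else ""
    return new_name in (format_tensor_name(t, bid, suffix) for t in SSM_TENSORS)

print(keep_as_f32("model.layers.3.mamba.A_log", "blk.3.ssm_a", 3))        # True
print(keep_as_f32("model.embed_tokens.weight", "token_embd.weight", None))  # False
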