convert : add BertForMaskedLM (#10919)

This commit is contained in:
Georgi Gerganov 2024-12-21 10:10:18 +02:00 committed by GitHub
parent a91a41364b
commit 5cd85b5e00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2628,7 +2628,7 @@ class InternLM2Model(Model):
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]
@Model.register("BertModel", "CamembertModel") @Model.register("BertModel", "BertForMaskedLM", "CamembertModel")
class BertModel(Model): class BertModel(Model):
model_arch = gguf.MODEL_ARCH.BERT model_arch = gguf.MODEL_ARCH.BERT
@ -2694,10 +2694,25 @@ class BertModel(Model):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused del bid # unused
if name.startswith("bert."):
name = name[5:]
if name.endswith(".gamma"):
name = name[:-6] + ".weight"
if name.endswith(".beta"):
name = name[:-5] + ".bias"
# we are only using BERT for embeddings so we don't need the pooling layer # we are only using BERT for embeddings so we don't need the pooling layer
if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
return [] # we don't need these return [] # we don't need these
if name.startswith("cls.predictions"):
return []
if name.startswith("cls.seq_relationship"):
return []
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]