mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 05:17:21 +01:00
convert : add BertForMaskedLM (#10919)
This commit is contained in:
parent
a91a41364b
commit
5cd85b5e00
@@ -2628,7 +2628,7 @@ class InternLM2Model(Model):
|
|||||||
return [(self.map_tensor_name(name), data_torch)]
|
return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
|
|
||||||
@Model.register("BertModel", "CamembertModel")
|
@Model.register("BertModel", "BertForMaskedLM", "CamembertModel")
|
||||||
class BertModel(Model):
|
class BertModel(Model):
|
||||||
model_arch = gguf.MODEL_ARCH.BERT
|
model_arch = gguf.MODEL_ARCH.BERT
|
||||||
|
|
||||||
@@ -2694,10 +2694,25 @@ class BertModel(Model):
|
|||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    """Rename a single checkpoint tensor for export, or drop it.

    The incoming ``name`` is normalised (a leading ``bert.`` is removed and
    ``.gamma`` / ``.beta`` suffixes become ``.weight`` / ``.bias``). Tensors
    that are not wanted yield an empty list; every other tensor is returned
    as a single ``(mapped_name, data_torch)`` pair, where the mapped name
    comes from ``self.map_tensor_name``.
    """
    del bid  # unused

    # Strip the optional "bert." prefix so the checks below see the bare name.
    bert_prefix = "bert."
    if name.startswith(bert_prefix):
        name = name[len(bert_prefix):]

    # Applied in order: ".gamma" -> ".weight", then ".beta" -> ".bias".
    for old_suffix, new_suffix in ((".gamma", ".weight"), (".beta", ".bias")):
        if name.endswith(old_suffix):
            name = name[:-len(old_suffix)] + new_suffix

    # we are only using BERT for embeddings so we don't need the pooling layer
    unused_names = ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias")
    # tensors under these "cls." prefixes are dropped as well
    unused_prefixes = ("cls.predictions", "cls.seq_relationship")
    if name in unused_names or name.startswith(unused_prefixes):
        return []  # we don't need these

    return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user