mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 14:20:31 +01:00
llama : fix Roberta embeddings (#10856)
* fix: Use gpt2 tokenizer for roberta and add eos/bos tokens Branch: RobertaTokenizer Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fixes to position embeddings Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * map roberta-bpe to gpt-2 Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fix linting Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> --------- Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> Co-authored-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent
7585edbdeb
commit
2fffc52b50
@ -2628,7 +2628,7 @@ class InternLM2Model(Model):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("BertModel", "CamembertModel", "RobertaModel")
|
||||
@Model.register("BertModel", "CamembertModel")
|
||||
class BertModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.BERT
|
||||
|
||||
@ -2701,6 +2701,51 @@ class BertModel(Model):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("RobertaModel")
|
||||
class RobertaModel(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.BERT
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# we need the pad_token_id to know how to chop down position_embd matrix
|
||||
if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
|
||||
self._position_offset = 1 + pad_token_id
|
||||
if "max_position_embeddings" in self.hparams:
|
||||
self.hparams["max_position_embeddings"] -= self._position_offset
|
||||
else:
|
||||
self._position_offset = None
|
||||
|
||||
def set_vocab(self):
|
||||
"""Support BPE tokenizers for roberta models"""
|
||||
bpe_tok_path = self.dir_model / "tokenizer.json"
|
||||
if bpe_tok_path.exists():
|
||||
self._set_vocab_gpt2()
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
# we need this to validate the size of the token_type embeddings
|
||||
# though currently we are passing all zeros to the token_type embeddings
|
||||
# "Sequence A" or "Sequence B"
|
||||
self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
|
||||
|
||||
else:
|
||||
return super().set_vocab()
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# if name starts with "roberta.", remove the prefix
|
||||
# e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
|
||||
if name.startswith("roberta."):
|
||||
name = name[8:]
|
||||
|
||||
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
|
||||
if name == "embeddings.position_embeddings.weight":
|
||||
if self._position_offset is not None:
|
||||
data_torch = data_torch[self._position_offset:,:]
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@Model.register("NomicBertModel")
|
||||
class NomicBertModel(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.NOMIC_BERT
|
||||
|
@ -6592,7 +6592,8 @@ static void llm_load_vocab(
|
||||
tokenizer_pre == "jina-v1-en" ||
|
||||
tokenizer_pre == "jina-v2-es" ||
|
||||
tokenizer_pre == "jina-v2-de" ||
|
||||
tokenizer_pre == "jina-v2-code") {
|
||||
tokenizer_pre == "jina-v2-code" ||
|
||||
tokenizer_pre == "roberta-bpe") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2;
|
||||
} else if (
|
||||
tokenizer_pre == "refact") {
|
||||
|
Loading…
Reference in New Issue
Block a user