9b hf chat support

mirror of https://github.com/ggerganov/llama.cpp.git
commit 1099ef271e (parent 9f5d80923e)
@@ -4622,53 +4622,12 @@ class ChatGLMModel(Model):
         vocab_size = hparams.get("padded_vocab_size",hparams["vocab_size"])
         assert max(tokenizer.get_vocab().values()) < vocab_size

-        if(hparams["partial_rotary_factor"] == 1.0):
-            # only for glm-edge series
-            tokens, toktypes, tokpre = self.get_vocab_base()
-            self.gguf_writer.add_tokenizer_model("gpt2")
-            self.gguf_writer.add_tokenizer_pre(tokpre)
-            self.gguf_writer.add_token_list(tokens)
-            self.gguf_writer.add_token_types(toktypes)
-            special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
-        else:
-            # for glm4 series
-            tokpre = self.get_vocab_base_pre(tokenizer)
-            merges = []
-            vocab = {}
-            mergeable_ranks = tokenizer._mergeable_ranks
-            for token, rank in mergeable_ranks.items():
-                vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
-                if len(token) == 1:
-                    continue
-                merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
-                assert len(merged) >= 2 and len(merged) <= 7
-                merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
-
-            # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
-            added_vocab = tokenizer.get_added_vocab()
-            reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
-
-            for i in range(vocab_size):
-                if i not in reverse_vocab:
-                    tokens.append(f"[PAD{i}]")
-                    toktypes.append(gguf.TokenType.UNUSED)
-                elif reverse_vocab[i] in added_vocab:
-                    tokens.append(reverse_vocab[i])
-                    if tokenizer.added_tokens_decoder[i].special:
-                        toktypes.append(gguf.TokenType.CONTROL)
-                    else:
-                        toktypes.append(gguf.TokenType.USER_DEFINED)
-                else:
-                    tokens.append(reverse_vocab[i])
-                    toktypes.append(gguf.TokenType.NORMAL)
-
-            self.gguf_writer.add_tokenizer_model("gpt2")
-            self.gguf_writer.add_tokenizer_pre(tokpre)
-            self.gguf_writer.add_token_list(tokens)
-            self.gguf_writer.add_token_types(toktypes)
-
-            special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
-            special_vocab.merges = merges
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
         # only add special tokens when they were not already loaded from config.json
         special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
         special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
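
Note (not part of the commit): the "eos"/"eot" IDs written above come directly from the Hugging Face tokenizer's added vocabulary. A minimal sketch of that lookup, assuming a GLM-4 9B chat checkpoint such as THUDM/glm-4-9b-chat (the checkpoint name is an assumption; the diff does not name one):

# Hedged sketch: where the special-token IDs used by the converter come from.
# Requires the transformers package and access to the model files; the
# checkpoint name below is illustrative, not taken from this commit.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
added = tok.get_added_vocab()          # dict: token string -> token id
eos_id = added["<|endoftext|>"]        # written to GGUF as the "eos" token
eot_id = added["<|user|>"]             # written to GGUF as the "eot" token
print(eos_id, eot_id)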
@@ -3085,6 +3085,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
                     layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
                     layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                    layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
                 }

                 layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
@@ -7215,13 +7215,25 @@ struct llm_build_context {
            struct ggml_tensor * Qcur = nullptr;
            struct ggml_tensor * Kcur = nullptr;
            struct ggml_tensor * Vcur = nullptr;
-           if(model.type == LLM_TYPE_1_5B|| model.type == LLM_TYPE_4B){
+           if(model.type == LLM_TYPE_1_5B|| model.type == LLM_TYPE_4B || model.type == LLM_TYPE_9B){
                Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                cb(Qcur, "Qcur", il);
+               if (model.layers[il].bq) {
+                   Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                   cb(Qcur, "Qcur", il);
+               }
                Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                cb(Kcur, "Kcur", il);
+               if (model.layers[il].bk) {
+                   Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                   cb(Kcur, "Kcur", il);
+               }
                Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                cb(Vcur, "Vcur", il);
+               if (model.layers[il].bv) {
+                   Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                   cb(Vcur, "Vcur", il);
+               }
            }else{
                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                cb(cur, "wqkv", il);
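
For intuition (not part of the commit), a small numpy sketch of what the two branches above compute per layer: the 1.5B/4B/9B path projects Q, K and V separately and adds a bias only when the corresponding tensor exists, while the other GLM variants run a single fused wqkv projection and split it afterwards. All names and shapes here are illustrative, not the real model dimensions.

# Split Q/K/V with optional biases vs. a fused wqkv projection (toy sizes).
import numpy as np

rng = np.random.default_rng(0)
n_embd, n_embd_gqa = 8, 4                      # toy dims, not the 9B values

cur = rng.standard_normal((3, n_embd))         # 3 tokens of hidden state
wq = rng.standard_normal((n_embd, n_embd))
wk = rng.standard_normal((n_embd, n_embd_gqa))
wv = rng.standard_normal((n_embd, n_embd_gqa))
bq, bk, bv = rng.standard_normal(n_embd), None, None   # biases are optional

def project(x, w, b):
    y = x @ w
    return y if b is None else y + b           # mirrors the "if (bias) add" guard

Qcur, Kcur, Vcur = project(cur, wq, bq), project(cur, wk, bk), project(cur, wv, bv)

# Fused path: one matmul against the concatenated weights, then split columns.
wqkv = np.concatenate([wq, wk, wv], axis=1)
fused = cur @ wqkv
assert np.allclose(fused[:, n_embd:n_embd + n_embd_gqa], Kcur)   # same K either way
print(Qcur.shape, Kcur.shape, Vcur.shape)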