mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
llama : add reranking support (#9510)
* py : add XLMRobertaForSequenceClassification [no ci] * py : fix scalar-tensor conversion [no ci] * py : fix position embeddings chop [no ci] * llama : read new cls tensors [no ci] * llama : add classigication head (wip) [no ci] * llama : add "rank" pooling type ggml-ci * server : add rerank endpoint ggml-ci * llama : aboud ggml_repeat during classification * rerank : cleanup + comments * server : accept /rerank endpoint in addition to /v1/rerank [no ci] * embedding : parse special tokens * jina : support v1 reranker * vocab : minor style ggml-ci * server : initiate tests for later ggml-ci * server : add docs * llama : add comment [no ci] * llama : fix uninitialized tensors * ci : add rerank tests ggml-ci * add reranking test * change test data * Update examples/server/server.cpp Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> * add `--reranking` argument * update server docs * llama : fix comment [no ci] ggml-ci --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
parent
1b2f992cd2
commit
f4d2b8846a
85
ci/run.sh
85
ci/run.sh
@ -712,6 +712,81 @@ function gg_run_embd_bge_small {
|
|||||||
set +e
|
set +e
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function gg_sum_embd_bge_small {
|
||||||
|
gg_printf '### %s\n\n' "${ci}"
|
||||||
|
|
||||||
|
gg_printf 'BGE Small (BERT):\n'
|
||||||
|
gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
|
||||||
|
gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)"
|
||||||
|
gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)"
|
||||||
|
}
|
||||||
|
|
||||||
|
# rerank_tiny
|
||||||
|
|
||||||
|
function gg_run_rerank_tiny {
|
||||||
|
cd ${SRC}
|
||||||
|
|
||||||
|
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
|
||||||
|
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer.json
|
||||||
|
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json
|
||||||
|
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json
|
||||||
|
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin
|
||||||
|
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json
|
||||||
|
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt
|
||||||
|
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json
|
||||||
|
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
|
||||||
|
|
||||||
|
gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json
|
||||||
|
|
||||||
|
path_models="../models-mnt/rerank-tiny"
|
||||||
|
|
||||||
|
rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
(time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
|
||||||
|
(time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log
|
||||||
|
|
||||||
|
python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf
|
||||||
|
|
||||||
|
model_f16="${path_models}/ggml-model-f16.gguf"
|
||||||
|
|
||||||
|
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?</s><s>hi\nwhat is panda?</s><s>it's a bear\nwhat is panda?</s><s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
|
||||||
|
|
||||||
|
# sample output
|
||||||
|
# rerank score 0: 0.029
|
||||||
|
# rerank score 1: 0.029
|
||||||
|
# rerank score 2: 0.135
|
||||||
|
|
||||||
|
# check that the score is in the range [$3, $4]
|
||||||
|
function check_score {
|
||||||
|
qnt="$1"
|
||||||
|
score=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1)
|
||||||
|
|
||||||
|
if [ $(echo "$score < $3" | bc) -eq 1 ] || [ $(echo "$score > $4" | bc) -eq 1 ]; then
|
||||||
|
printf ' - %s @ %s (FAIL: score not in range [%s, %s])\n' "$qnt" "$score" "$3" "$4"
|
||||||
|
return 20
|
||||||
|
fi
|
||||||
|
|
||||||
|
printf ' - %s @ %s OK\n' "$qnt" "$score"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
check_score "rerank score 0" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 0")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log
|
||||||
|
check_score "rerank score 1" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 1")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log
|
||||||
|
check_score "rerank score 2" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 2")" "0.10" "0.15" | tee -a $OUT/${ci}-rk-f16.log
|
||||||
|
|
||||||
|
set +e
|
||||||
|
}
|
||||||
|
|
||||||
|
function gg_sum_rerank_tiny {
|
||||||
|
gg_printf '### %s\n\n' "${ci}"
|
||||||
|
|
||||||
|
gg_printf 'Rerank Tiny (Jina):\n'
|
||||||
|
gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
|
||||||
|
gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-rk-f16.log)"
|
||||||
|
}
|
||||||
|
|
||||||
function gg_check_build_requirements {
|
function gg_check_build_requirements {
|
||||||
if ! command -v cmake &> /dev/null; then
|
if ! command -v cmake &> /dev/null; then
|
||||||
gg_printf 'cmake not found, please install'
|
gg_printf 'cmake not found, please install'
|
||||||
@ -726,15 +801,6 @@ function gg_check_build_requirements {
|
|||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
function gg_sum_embd_bge_small {
|
|
||||||
gg_printf '### %s\n\n' "${ci}"
|
|
||||||
|
|
||||||
gg_printf 'BGE Small (BERT):\n'
|
|
||||||
gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
|
|
||||||
gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)"
|
|
||||||
gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)"
|
|
||||||
}
|
|
||||||
|
|
||||||
## main
|
## main
|
||||||
|
|
||||||
export LLAMA_LOG_PREFIX=1
|
export LLAMA_LOG_PREFIX=1
|
||||||
@ -762,6 +828,7 @@ test $ret -eq 0 && gg_run ctest_release
|
|||||||
|
|
||||||
if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
||||||
test $ret -eq 0 && gg_run embd_bge_small
|
test $ret -eq 0 && gg_run embd_bge_small
|
||||||
|
test $ret -eq 0 && gg_run rerank_tiny
|
||||||
|
|
||||||
if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
|
if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
|
||||||
test $ret -eq 0 && gg_run test_scripts_debug
|
test $ret -eq 0 && gg_run test_scripts_debug
|
||||||
|
@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
|
|||||||
params.kv_overrides.back().key[0] = 0;
|
params.kv_overrides.back().key[0] = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (params.reranking && params.embedding) {
|
||||||
|
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -391,7 +395,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
|||||||
[](gpt_params & params) {
|
[](gpt_params & params) {
|
||||||
params.verbose_prompt = true;
|
params.verbose_prompt = true;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_MAIN}));
|
));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--no-display-prompt"},
|
{"--no-display-prompt"},
|
||||||
format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
|
format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
|
||||||
@ -1093,13 +1097,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
|||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--pooling"}, "{none,mean,cls,last}",
|
{"--pooling"}, "{none,mean,cls,last,rank}",
|
||||||
"pooling type for embeddings, use model default if unspecified",
|
"pooling type for embeddings, use model default if unspecified",
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
|
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
|
||||||
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
|
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
|
||||||
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
|
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
|
||||||
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
|
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
|
||||||
|
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
|
||||||
else { throw std::invalid_argument("invalid value"); }
|
else { throw std::invalid_argument("invalid value"); }
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
|
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
|
||||||
@ -1749,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
|||||||
params.embedding = true;
|
params.embedding = true;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
|
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
|
||||||
|
add_opt(llama_arg(
|
||||||
|
{"--reranking", "--rerank"},
|
||||||
|
format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
|
||||||
|
[](gpt_params & params) {
|
||||||
|
params.reranking = true;
|
||||||
|
}
|
||||||
|
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--api-key"}, "KEY",
|
{"--api-key"}, "KEY",
|
||||||
"API key to use for authentication (default: none)",
|
"API key to use for authentication (default: none)",
|
||||||
|
@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
|||||||
cparams.flash_attn = params.flash_attn;
|
cparams.flash_attn = params.flash_attn;
|
||||||
cparams.no_perf = params.no_perf;
|
cparams.no_perf = params.no_perf;
|
||||||
|
|
||||||
|
if (params.reranking) {
|
||||||
|
cparams.embeddings = true;
|
||||||
|
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
||||||
|
}
|
||||||
|
|
||||||
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
||||||
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
|
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
|
||||||
|
|
||||||
|
@ -271,6 +271,7 @@ struct gpt_params {
|
|||||||
int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
||||||
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
|
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
|
||||||
std::string embd_sep = "\n"; // separator of embendings
|
std::string embd_sep = "\n"; // separator of embendings
|
||||||
|
bool reranking = false; // enable reranking support on server
|
||||||
|
|
||||||
// server params
|
// server params
|
||||||
int32_t port = 8080; // server listens on this network port
|
int32_t port = 8080; // server listens on this network port
|
||||||
|
@ -291,8 +291,13 @@ class Model:
|
|||||||
bid = int(part)
|
bid = int(part)
|
||||||
break
|
break
|
||||||
|
|
||||||
for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
|
for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
|
||||||
data: np.ndarray # type hint
|
data = data_torch.squeeze().numpy()
|
||||||
|
|
||||||
|
# if data ends up empty, it means data_torch was a scalar tensor -> restore
|
||||||
|
if len(data.shape) == 0:
|
||||||
|
data = data_torch.numpy()
|
||||||
|
|
||||||
n_dims = len(data.shape)
|
n_dims = len(data.shape)
|
||||||
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
|
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
|
||||||
|
|
||||||
@ -592,6 +597,9 @@ class Model:
|
|||||||
if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
|
if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
|
||||||
# ref: https://huggingface.co/databricks/dbrx-base
|
# ref: https://huggingface.co/databricks/dbrx-base
|
||||||
res = "dbrx"
|
res = "dbrx"
|
||||||
|
if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448":
|
||||||
|
# ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||||
|
res = "jina-v1-en"
|
||||||
if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
|
if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
|
||||||
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
|
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
|
||||||
res = "jina-v2-en"
|
res = "jina-v2-en"
|
||||||
@ -2601,7 +2609,7 @@ class NomicBertModel(BertModel):
|
|||||||
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
|
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
|
||||||
|
|
||||||
|
|
||||||
@Model.register("XLMRobertaModel")
|
@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
|
||||||
class XLMRobertaModel(BertModel):
|
class XLMRobertaModel(BertModel):
|
||||||
model_arch = gguf.MODEL_ARCH.BERT
|
model_arch = gguf.MODEL_ARCH.BERT
|
||||||
|
|
||||||
@ -2699,6 +2707,11 @@ class XLMRobertaModel(BertModel):
|
|||||||
self.gguf_writer.add_add_eos_token(True)
|
self.gguf_writer.add_add_eos_token(True)
|
||||||
|
|
||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
|
# if name starts with "roberta.", remove the prefix
|
||||||
|
# e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
|
||||||
|
if name.startswith("roberta."):
|
||||||
|
name = name[8:]
|
||||||
|
|
||||||
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
|
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
|
||||||
if name == "embeddings.position_embeddings.weight":
|
if name == "embeddings.position_embeddings.weight":
|
||||||
if self._position_offset is not None:
|
if self._position_offset is not None:
|
||||||
@ -3110,6 +3123,14 @@ class JinaBertV2Model(BertModel):
|
|||||||
self.gguf_writer.add_add_bos_token(True)
|
self.gguf_writer.add_add_bos_token(True)
|
||||||
self.gguf_writer.add_add_eos_token(True)
|
self.gguf_writer.add_add_eos_token(True)
|
||||||
|
|
||||||
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
|
# if name starts with "bert.", remove the prefix
|
||||||
|
# e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||||
|
if name.startswith("bert."):
|
||||||
|
name = name[5:]
|
||||||
|
|
||||||
|
return super().modify_tensors(data_torch, name, bid)
|
||||||
|
|
||||||
|
|
||||||
@Model.register("OpenELMForCausalLM")
|
@Model.register("OpenELMForCausalLM")
|
||||||
class OpenELMModel(Model):
|
class OpenELMModel(Model):
|
||||||
|
@ -81,6 +81,7 @@ models = [
|
|||||||
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
|
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
|
||||||
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
|
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
|
||||||
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
|
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
|
||||||
|
{"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", },
|
||||||
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
|
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
|
||||||
{"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
|
{"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
|
||||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
|
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
|
||||||
|
@ -135,7 +135,7 @@ int main(int argc, char ** argv) {
|
|||||||
// tokenize the prompts and trim
|
// tokenize the prompts and trim
|
||||||
std::vector<std::vector<int32_t>> inputs;
|
std::vector<std::vector<int32_t>> inputs;
|
||||||
for (const auto & prompt : prompts) {
|
for (const auto & prompt : prompts) {
|
||||||
auto inp = ::llama_tokenize(ctx, prompt, true, false);
|
auto inp = ::llama_tokenize(ctx, prompt, true, true);
|
||||||
if (inp.size() > n_batch) {
|
if (inp.size() > n_batch) {
|
||||||
LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
|
LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
|
||||||
__func__, (long long int) inp.size(), (long long int) n_batch);
|
__func__, (long long int) inp.size(), (long long int) n_batch);
|
||||||
@ -234,6 +234,11 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
LOG("\n");
|
LOG("\n");
|
||||||
}
|
}
|
||||||
|
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
|
||||||
|
for (int j = 0; j < n_embd_count; j++) {
|
||||||
|
// NOTE: if you change this log - update the tests in ci/run.sh
|
||||||
|
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// print the first part of the embeddings or for a single prompt, the full embedding
|
// print the first part of the embeddings or for a single prompt, the full embedding
|
||||||
for (int j = 0; j < n_prompts; j++) {
|
for (int j = 0; j < n_prompts; j++) {
|
||||||
|
@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
|
|||||||
**Features:**
|
**Features:**
|
||||||
* LLM inference of F16 and quantized models on GPU and CPU
|
* LLM inference of F16 and quantized models on GPU and CPU
|
||||||
* [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
|
* [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
|
||||||
|
* Reranking endoint (WIP: https://github.com/ggerganov/llama.cpp/pull/9510)
|
||||||
* Parallel decoding with multi-user support
|
* Parallel decoding with multi-user support
|
||||||
* Continuous batching
|
* Continuous batching
|
||||||
* Multimodal (wip)
|
* Multimodal (wip)
|
||||||
@ -23,6 +24,7 @@ The project is under active development, and we are [looking for feedback and co
|
|||||||
| -------- | ----------- |
|
| -------- | ----------- |
|
||||||
| `-h, --help, --usage` | print usage and exit |
|
| `-h, --help, --usage` | print usage and exit |
|
||||||
| `--version` | show version and build info |
|
| `--version` | show version and build info |
|
||||||
|
| `--verbose-prompt` | print a verbose prompt before generation (default: false) |
|
||||||
| `-t, --threads N` | number of threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
|
| `-t, --threads N` | number of threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
|
||||||
| `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
|
| `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
|
||||||
| `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
|
| `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
|
||||||
@ -130,7 +132,7 @@ The project is under active development, and we are [looking for feedback and co
|
|||||||
| `--no-context-shift` | disables context shift on inifinite text generation (default: disabled)<br/>(env: LLAMA_ARG_NO_CONTEXT_SHIFT) |
|
| `--no-context-shift` | disables context shift on inifinite text generation (default: disabled)<br/>(env: LLAMA_ARG_NO_CONTEXT_SHIFT) |
|
||||||
| `-sp, --special` | special tokens output enabled (default: false) |
|
| `-sp, --special` | special tokens output enabled (default: false) |
|
||||||
| `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
|
| `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
|
||||||
| `--pooling {none,mean,cls,last}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
|
| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
|
||||||
| `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
|
| `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
|
||||||
| `-nocb, --no-cont-batching` | disable continuous batching<br/>(env: LLAMA_ARG_NO_CONT_BATCHING) |
|
| `-nocb, --no-cont-batching` | disable continuous batching<br/>(env: LLAMA_ARG_NO_CONT_BATCHING) |
|
||||||
| `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
|
| `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
|
||||||
@ -138,6 +140,7 @@ The project is under active development, and we are [looking for feedback and co
|
|||||||
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
|
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
|
||||||
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
|
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
|
||||||
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
|
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
|
||||||
|
| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
|
||||||
| `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
|
| `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
|
||||||
| `--api-key-file FNAME` | path to file containing API keys (default: none) |
|
| `--api-key-file FNAME` | path to file containing API keys (default: none) |
|
||||||
| `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
|
| `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
|
||||||
@ -152,6 +155,7 @@ The project is under active development, and we are [looking for feedback and co
|
|||||||
| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
|
| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
|
||||||
| `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
|
| `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
|
||||||
|
|
||||||
|
|
||||||
Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var.
|
Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var.
|
||||||
|
|
||||||
Example usage of docker compose with environment variables:
|
Example usage of docker compose with environment variables:
|
||||||
@ -478,6 +482,39 @@ The same as [the embedding example](../embedding) does.
|
|||||||
|
|
||||||
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
|
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
|
||||||
|
|
||||||
|
### POST `/reranking`: Rerank documents according to a given query
|
||||||
|
|
||||||
|
Similar to https://jina.ai/reranker/ but might change in the future.
|
||||||
|
Requires a reranker model (such as [bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)) and the `--embedding --pooling rank` options.
|
||||||
|
|
||||||
|
*Options:*
|
||||||
|
|
||||||
|
`query`: The query against which the documents will be ranked.
|
||||||
|
|
||||||
|
`documents`: An array strings representing the documents to be ranked.
|
||||||
|
|
||||||
|
*Aliases:*
|
||||||
|
- `/rerank`
|
||||||
|
- `/v1/rerank`
|
||||||
|
- `/v1/reranking`
|
||||||
|
|
||||||
|
*Examples:*
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl http://127.0.0.1:8012/v1/rerank \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "some-model",
|
||||||
|
"query": "What is panda?",
|
||||||
|
"top_n": 3,
|
||||||
|
"documents": [
|
||||||
|
"hi",
|
||||||
|
"it is a bear",
|
||||||
|
"The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."
|
||||||
|
]
|
||||||
|
}' | jq
|
||||||
|
```
|
||||||
|
|
||||||
### POST `/infill`: For code infilling.
|
### POST `/infill`: For code infilling.
|
||||||
|
|
||||||
Takes a prefix and a suffix and returns the predicted completion as stream.
|
Takes a prefix and a suffix and returns the predicted completion as stream.
|
||||||
|
@ -92,6 +92,7 @@ enum server_task_type {
|
|||||||
enum server_task_cmpl_type {
|
enum server_task_cmpl_type {
|
||||||
SERVER_TASK_CMPL_TYPE_NORMAL,
|
SERVER_TASK_CMPL_TYPE_NORMAL,
|
||||||
SERVER_TASK_CMPL_TYPE_EMBEDDING,
|
SERVER_TASK_CMPL_TYPE_EMBEDDING,
|
||||||
|
SERVER_TASK_CMPL_TYPE_RERANK,
|
||||||
SERVER_TASK_CMPL_TYPE_INFILL,
|
SERVER_TASK_CMPL_TYPE_INFILL,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -172,6 +173,7 @@ struct server_slot {
|
|||||||
std::vector<completion_token_output> generated_token_probs;
|
std::vector<completion_token_output> generated_token_probs;
|
||||||
|
|
||||||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
||||||
|
|
||||||
bool has_next_token = true;
|
bool has_next_token = true;
|
||||||
bool truncated = false;
|
bool truncated = false;
|
||||||
bool stopped_eos = false;
|
bool stopped_eos = false;
|
||||||
@ -954,8 +956,17 @@ struct server_context {
|
|||||||
slot.prompt = *prompt;
|
slot.prompt = *prompt;
|
||||||
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
||||||
slot.prompt = prompt->at(0);
|
slot.prompt = prompt->at(0);
|
||||||
|
} else if (prompt->is_array() && prompt->size() > 1) {
|
||||||
|
// array of strings
|
||||||
|
for (const auto & el : *prompt) {
|
||||||
|
if (!el.is_string()) {
|
||||||
|
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slot.prompt = *prompt;
|
||||||
} else {
|
} else {
|
||||||
send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1389,6 +1400,7 @@ struct server_context {
|
|||||||
|
|
||||||
res.data = json {
|
res.data = json {
|
||||||
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
||||||
|
{"index", slot.index},
|
||||||
};
|
};
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
@ -1407,6 +1419,44 @@ struct server_context {
|
|||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void send_rerank(const server_slot & slot, const llama_batch & batch) {
|
||||||
|
server_task_result res;
|
||||||
|
res.id = slot.id_task;
|
||||||
|
res.error = false;
|
||||||
|
res.stop = true;
|
||||||
|
|
||||||
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
|
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||||
|
if (embd == NULL) {
|
||||||
|
embd = llama_get_embeddings_ith(ctx, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (embd == NULL) {
|
||||||
|
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
||||||
|
|
||||||
|
res.data = json {
|
||||||
|
{"index", slot.index},
|
||||||
|
{"score", -1e6},
|
||||||
|
};
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
res.data = json {
|
||||||
|
{"index", slot.index},
|
||||||
|
{"score", embd[0]},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
|
||||||
|
|
||||||
|
queue_results.send(res);
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Functions to create new task(s) and receive result(s)
|
// Functions to create new task(s) and receive result(s)
|
||||||
//
|
//
|
||||||
@ -1442,6 +1492,19 @@ struct server_context {
|
|||||||
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
|
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
|
||||||
else if (prompt.is_array()) {
|
else if (prompt.is_array()) {
|
||||||
std::vector<json> prompts = prompt;
|
std::vector<json> prompts = prompt;
|
||||||
|
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
|
// prompts[0] is the question
|
||||||
|
// the rest are the answers/documents
|
||||||
|
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
|
||||||
|
for (size_t i = 1; i < prompts.size(); i++) {
|
||||||
|
json qd;
|
||||||
|
qd.push_back(prompts[0]);
|
||||||
|
qd.push_back(prompts[i]);
|
||||||
|
data["index"] = i - 1;
|
||||||
|
create_task(data, true, qd);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
|
||||||
for (size_t i = 0; i < prompts.size(); i++) {
|
for (size_t i = 0; i < prompts.size(); i++) {
|
||||||
const auto & e = prompts[i];
|
const auto & e = prompts[i];
|
||||||
if (e.is_string() || json_is_array_of_numbers(e)) {
|
if (e.is_string() || json_is_array_of_numbers(e)) {
|
||||||
@ -1452,6 +1515,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// invalid case
|
// invalid case
|
||||||
else {
|
else {
|
||||||
throw std::runtime_error(error_msg);
|
throw std::runtime_error(error_msg);
|
||||||
@ -1492,7 +1556,9 @@ struct server_context {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t idx = result.data["index"];
|
const size_t idx = result.data["index"];
|
||||||
|
GGML_ASSERT(idx < results.size() && "index out of range");
|
||||||
|
|
||||||
results[idx] = result;
|
results[idx] = result;
|
||||||
}
|
}
|
||||||
result_handler(results);
|
result_handler(results);
|
||||||
@ -1903,6 +1969,7 @@ struct server_context {
|
|||||||
// track if this is an embedding or non-embedding batch
|
// track if this is an embedding or non-embedding batch
|
||||||
// if we've added sampled tokens above, we are in non-embedding mode
|
// if we've added sampled tokens above, we are in non-embedding mode
|
||||||
// -1: none, 0: non-embedding, 1: embedding
|
// -1: none, 0: non-embedding, 1: embedding
|
||||||
|
// TODO: make enum
|
||||||
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
||||||
|
|
||||||
// next, batch any pending prompts without exceeding n_batch
|
// next, batch any pending prompts without exceeding n_batch
|
||||||
@ -1951,6 +2018,29 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
prompt_tokens = embd_inp;
|
prompt_tokens = embd_inp;
|
||||||
|
} else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
|
// require slot.prompt to be array of 2 strings
|
||||||
|
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
|
||||||
|
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
|
||||||
|
slot.release();
|
||||||
|
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// prompt: <s>query</s><s>doc</s>
|
||||||
|
prompt_tokens.clear();
|
||||||
|
prompt_tokens.push_back(llama_token_bos(model));
|
||||||
|
{
|
||||||
|
const auto part = tokenize(slot.prompt[0], false);
|
||||||
|
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
||||||
|
}
|
||||||
|
prompt_tokens.push_back(llama_token_eos(model));
|
||||||
|
prompt_tokens.push_back(llama_token_bos(model));
|
||||||
|
{
|
||||||
|
const auto part = tokenize(slot.prompt[1], false);
|
||||||
|
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
||||||
|
}
|
||||||
|
prompt_tokens.push_back(llama_token_eos(model));
|
||||||
} else {
|
} else {
|
||||||
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
||||||
}
|
}
|
||||||
@ -1970,7 +2060,7 @@ struct server_context {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
|
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
// this prompt is too large to process - discard it
|
// this prompt is too large to process - discard it
|
||||||
if (slot.n_prompt_tokens > n_ubatch) {
|
if (slot.n_prompt_tokens > n_ubatch) {
|
||||||
slot.release();
|
slot.release();
|
||||||
@ -2048,7 +2138,8 @@ struct server_context {
|
|||||||
slot.n_prompt_tokens_processed = 0;
|
slot.n_prompt_tokens_processed = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
|
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||||
|
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
// cannot fit the prompt in the current batch - will try next iter
|
// cannot fit the prompt in the current batch - will try next iter
|
||||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
||||||
continue;
|
continue;
|
||||||
@ -2056,7 +2147,10 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check that we are in the right batch_type, if not defer the slot
|
// check that we are in the right batch_type, if not defer the slot
|
||||||
bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
|
const bool slot_type =
|
||||||
|
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
|
||||||
|
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
|
||||||
|
|
||||||
if (batch_type == -1) {
|
if (batch_type == -1) {
|
||||||
batch_type = slot_type;
|
batch_type = slot_type;
|
||||||
} else if (batch_type != slot_type) {
|
} else if (batch_type != slot_type) {
|
||||||
@ -2229,6 +2323,13 @@ struct server_context {
|
|||||||
continue; // continue loop of slots
|
continue; // continue loop of slots
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||||
|
send_rerank(slot, batch_view);
|
||||||
|
slot.release();
|
||||||
|
slot.i_batch = -1;
|
||||||
|
continue; // continue loop of slots
|
||||||
|
}
|
||||||
|
|
||||||
// prompt evaluated for next-token prediction
|
// prompt evaluated for next-token prediction
|
||||||
slot.state = SLOT_STATE_GENERATING;
|
slot.state = SLOT_STATE_GENERATING;
|
||||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||||
@ -2787,8 +2888,8 @@ int main(int argc, char ** argv) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
|
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
|
||||||
if (ctx_server.params.embedding) {
|
if (ctx_server.params.embedding || ctx_server.params.reranking) {
|
||||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2848,8 +2949,8 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// TODO: maybe merge this function with "handle_completions_generic"
|
// TODO: maybe merge this function with "handle_completions_generic"
|
||||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
|
||||||
if (ctx_server.params.embedding) {
|
if (ctx_server.params.embedding || ctx_server.params.reranking) {
|
||||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2973,6 +3074,11 @@ int main(int argc, char ** argv) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
// TODO: somehow clean up this checks in the future
|
||||||
|
if (!ctx_server.params.embedding || ctx_server.params.reranking) {
|
||||||
|
res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
|
return;
|
||||||
|
}
|
||||||
const json body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
bool is_openai = false;
|
bool is_openai = false;
|
||||||
|
|
||||||
@ -3023,6 +3129,79 @@ int main(int argc, char ** argv) {
|
|||||||
res_ok(res, root);
|
res_ok(res, root);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
if (!ctx_server.params.reranking) {
|
||||||
|
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const json body = json::parse(req.body);
|
||||||
|
|
||||||
|
// TODO: implement
|
||||||
|
//int top_n = 1;
|
||||||
|
//if (body.count("top_n") != 1) {
|
||||||
|
// top_n = body.at("top_n");
|
||||||
|
//} else {
|
||||||
|
// res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
// return;
|
||||||
|
//}
|
||||||
|
|
||||||
|
json query;
|
||||||
|
if (body.count("query") == 1) {
|
||||||
|
query = body.at("query");
|
||||||
|
if (!query.is_string()) {
|
||||||
|
res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
|
||||||
|
if (documents.empty()) {
|
||||||
|
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// construct prompt object: array of ["query", "doc0", "doc1", ...]
|
||||||
|
json prompt;
|
||||||
|
prompt.push_back(query);
|
||||||
|
for (const auto & doc : documents) {
|
||||||
|
prompt.push_back(doc);
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
|
||||||
|
|
||||||
|
// create and queue the task
|
||||||
|
json responses = json::array();
|
||||||
|
bool error = false;
|
||||||
|
{
|
||||||
|
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
|
||||||
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
|
// get the result
|
||||||
|
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
||||||
|
|
||||||
|
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
||||||
|
for (const auto & res : results) {
|
||||||
|
responses.push_back(res.data);
|
||||||
|
}
|
||||||
|
}, [&](const json & error_data) {
|
||||||
|
res_error(res, error_data);
|
||||||
|
error = true;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// write JSON response
|
||||||
|
json root = format_response_rerank(body, responses);
|
||||||
|
res_ok(res, root);
|
||||||
|
};
|
||||||
|
|
||||||
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
||||||
json result = json::array();
|
json result = json::array();
|
||||||
for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
|
for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
|
||||||
@ -3119,6 +3298,10 @@ int main(int argc, char ** argv) {
|
|||||||
svr->Post("/embedding", handle_embeddings); // legacy
|
svr->Post("/embedding", handle_embeddings); // legacy
|
||||||
svr->Post("/embeddings", handle_embeddings);
|
svr->Post("/embeddings", handle_embeddings);
|
||||||
svr->Post("/v1/embeddings", handle_embeddings);
|
svr->Post("/v1/embeddings", handle_embeddings);
|
||||||
|
svr->Post("/rerank", handle_rerank);
|
||||||
|
svr->Post("/reranking", handle_rerank);
|
||||||
|
svr->Post("/v1/rerank", handle_rerank);
|
||||||
|
svr->Post("/v1/reranking", handle_rerank);
|
||||||
svr->Post("/tokenize", handle_tokenize);
|
svr->Post("/tokenize", handle_tokenize);
|
||||||
svr->Post("/detokenize", handle_detokenize);
|
svr->Post("/detokenize", handle_detokenize);
|
||||||
// LoRA adapters hotswap
|
// LoRA adapters hotswap
|
||||||
|
@ -15,7 +15,7 @@ Feature: llama.cpp server
|
|||||||
And 128 as batch size
|
And 128 as batch size
|
||||||
And 128 as ubatch size
|
And 128 as ubatch size
|
||||||
And 512 KV cache size
|
And 512 KV cache size
|
||||||
And embeddings extraction
|
And enable embeddings endpoint
|
||||||
Then the server is starting
|
Then the server is starting
|
||||||
Then the server is healthy
|
Then the server is healthy
|
||||||
|
|
||||||
|
42
examples/server/tests/features/rerank.feature
Normal file
42
examples/server/tests/features/rerank.feature
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
@llama.cpp
|
||||||
|
@rerank
|
||||||
|
Feature: llama.cpp server
|
||||||
|
|
||||||
|
Background: Server startup
|
||||||
|
Given a server listening on localhost:8080
|
||||||
|
And a model url https://huggingface.co/ggml-org/models/resolve/main/jina-reranker-v1-tiny-en/ggml-model-f16.gguf
|
||||||
|
And a model file jina-reranker-v1-tiny-en.gguf
|
||||||
|
And a model alias jina-reranker-v1-tiny-en
|
||||||
|
And 42 as server seed
|
||||||
|
And 2 slots
|
||||||
|
And 512 as batch size
|
||||||
|
And 512 as ubatch size
|
||||||
|
And 512 KV cache size
|
||||||
|
And enable reranking endpoint
|
||||||
|
Then the server is starting
|
||||||
|
Then the server is healthy
|
||||||
|
|
||||||
|
Scenario: Rerank
|
||||||
|
Given a rerank query:
|
||||||
|
"""
|
||||||
|
Machine learning is
|
||||||
|
"""
|
||||||
|
And a rerank document:
|
||||||
|
"""
|
||||||
|
A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.
|
||||||
|
"""
|
||||||
|
And a rerank document:
|
||||||
|
"""
|
||||||
|
Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.
|
||||||
|
"""
|
||||||
|
And a rerank document:
|
||||||
|
"""
|
||||||
|
Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.
|
||||||
|
"""
|
||||||
|
And a rerank document:
|
||||||
|
"""
|
||||||
|
Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.
|
||||||
|
"""
|
||||||
|
When reranking request
|
||||||
|
Then reranking results are returned
|
||||||
|
Then reranking highest score is index 2 and lowest score is index 3
|
@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
|
|||||||
context.server_api_key = None
|
context.server_api_key = None
|
||||||
context.server_continuous_batching = False
|
context.server_continuous_batching = False
|
||||||
context.server_embeddings = False
|
context.server_embeddings = False
|
||||||
|
context.server_reranking = False
|
||||||
context.server_metrics = False
|
context.server_metrics = False
|
||||||
context.server_process = None
|
context.server_process = None
|
||||||
context.seed = None
|
context.seed = None
|
||||||
@ -83,6 +84,10 @@ def step_server_config(context, server_fqdn: str, server_port: str):
|
|||||||
context.concurrent_tasks = []
|
context.concurrent_tasks = []
|
||||||
context.prompts = []
|
context.prompts = []
|
||||||
|
|
||||||
|
context.reranking_query = None
|
||||||
|
context.reranking_documents = []
|
||||||
|
context.reranking_results = None
|
||||||
|
|
||||||
|
|
||||||
@step('a model file {hf_file} from HF repo {hf_repo}')
|
@step('a model file {hf_file} from HF repo {hf_repo}')
|
||||||
def step_download_hf_model(context, hf_file: str, hf_repo: str):
|
def step_download_hf_model(context, hf_file: str, hf_repo: str):
|
||||||
@ -172,10 +177,13 @@ def step_server_continuous_batching(context):
|
|||||||
context.server_continuous_batching = True
|
context.server_continuous_batching = True
|
||||||
|
|
||||||
|
|
||||||
@step('embeddings extraction')
|
@step('enable embeddings endpoint')
|
||||||
def step_server_embeddings(context):
|
def step_server_embeddings(context):
|
||||||
context.server_embeddings = True
|
context.server_embeddings = True
|
||||||
|
|
||||||
|
@step('enable reranking endpoint')
|
||||||
|
def step_server_reranking(context):
|
||||||
|
context.server_reranking = True
|
||||||
|
|
||||||
@step('prometheus compatible metrics exposed')
|
@step('prometheus compatible metrics exposed')
|
||||||
def step_server_metrics(context):
|
def step_server_metrics(context):
|
||||||
@ -452,6 +460,14 @@ def step_impl(context, n_ga_w):
|
|||||||
def step_prompt_passkey(context):
|
def step_prompt_passkey(context):
|
||||||
context.prompt_passkey = context_text(context)
|
context.prompt_passkey = context_text(context)
|
||||||
|
|
||||||
|
@step('a rerank query')
|
||||||
|
def step_set_rerank_query(context):
|
||||||
|
context.reranking_query = context_text(context)
|
||||||
|
context.reranking_documents = []
|
||||||
|
|
||||||
|
@step('a rerank document')
|
||||||
|
def step_set_rerank_document(context):
|
||||||
|
context.reranking_documents.append(context_text(context))
|
||||||
|
|
||||||
@step('{n_prompts:d} fixed prompts')
|
@step('{n_prompts:d} fixed prompts')
|
||||||
def step_fixed_prompts(context, n_prompts):
|
def step_fixed_prompts(context, n_prompts):
|
||||||
@ -619,6 +635,22 @@ async def step_compute_embedding(context):
|
|||||||
context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url)
|
context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url)
|
||||||
|
|
||||||
|
|
||||||
|
@step('reranking request')
|
||||||
|
@async_run_until_complete
|
||||||
|
async def step_compute_reranking(context):
|
||||||
|
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
|
||||||
|
async with session.post(f'{context.base_url}/reranking',
|
||||||
|
json={
|
||||||
|
"query": context.reranking_query,
|
||||||
|
"documents": context.reranking_documents,
|
||||||
|
}) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
response_json = await response.json()
|
||||||
|
context.reranking_results = response_json['results']
|
||||||
|
else:
|
||||||
|
context.reranking_results = response.status
|
||||||
|
|
||||||
|
|
||||||
@step('all embeddings are the same')
|
@step('all embeddings are the same')
|
||||||
@async_run_until_complete
|
@async_run_until_complete
|
||||||
async def step_all_embeddings_are_the_same(context):
|
async def step_all_embeddings_are_the_same(context):
|
||||||
@ -704,6 +736,24 @@ async def all_embeddings_are_generated(context):
|
|||||||
for i in range(n_embedding_requests):
|
for i in range(n_embedding_requests):
|
||||||
assert_embeddings(context.tasks_result.pop().pop())
|
assert_embeddings(context.tasks_result.pop().pop())
|
||||||
|
|
||||||
|
@step('reranking results are returned')
|
||||||
|
def reranking_results_are_returned(context):
|
||||||
|
assert len(context.reranking_results) == len(context.reranking_documents)
|
||||||
|
|
||||||
|
@step('reranking highest score is index {idx_high:d} and lowest score is index {idx_low:d}')
|
||||||
|
def reranking_results_are_returned(context, idx_high: int, idx_low: int):
|
||||||
|
max_score, max_idx = 0, 0
|
||||||
|
min_score, min_idx = 0, 0
|
||||||
|
for res in context.reranking_results:
|
||||||
|
if max_score < res['relevance_score']:
|
||||||
|
max_score = res['relevance_score']
|
||||||
|
max_idx = res['index']
|
||||||
|
if min_score > res['relevance_score']:
|
||||||
|
min_score = res['relevance_score']
|
||||||
|
min_idx = res['index']
|
||||||
|
print(context.reranking_results)
|
||||||
|
assert max_idx == idx_high
|
||||||
|
assert min_idx == idx_low
|
||||||
|
|
||||||
@step('adding special tokens')
|
@step('adding special tokens')
|
||||||
def step_tokenize_set_add_special(context):
|
def step_tokenize_set_add_special(context):
|
||||||
@ -1362,6 +1412,8 @@ def start_server_background(context):
|
|||||||
server_args.append('--cont-batching')
|
server_args.append('--cont-batching')
|
||||||
if context.server_embeddings:
|
if context.server_embeddings:
|
||||||
server_args.append('--embedding')
|
server_args.append('--embedding')
|
||||||
|
if context.server_reranking:
|
||||||
|
server_args.append('--reranking')
|
||||||
if context.server_metrics:
|
if context.server_metrics:
|
||||||
server_args.append('--metrics')
|
server_args.append('--metrics')
|
||||||
if context.model_alias:
|
if context.model_alias:
|
||||||
|
@ -537,7 +537,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
|
|||||||
json res = json {
|
json res = json {
|
||||||
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||||
{"object", "list"},
|
{"object", "list"},
|
||||||
{"usage", json {
|
{"usage", json { // TODO: fill
|
||||||
{"prompt_tokens", 0},
|
{"prompt_tokens", 0},
|
||||||
{"total_tokens", 0}
|
{"total_tokens", 0}
|
||||||
}},
|
}},
|
||||||
@ -547,6 +547,29 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static json format_response_rerank(const json & request, const json & ranks) {
|
||||||
|
json data = json::array();
|
||||||
|
int i = 0;
|
||||||
|
for (const auto & rank : ranks) {
|
||||||
|
data.push_back(json{
|
||||||
|
{"index", i++},
|
||||||
|
{"relevance_score", json_value(rank, "score", 0.0)},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
json res = json {
|
||||||
|
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||||
|
{"object", "list"},
|
||||||
|
{"usage", json { // TODO: fill
|
||||||
|
{"prompt_tokens", 0},
|
||||||
|
{"total_tokens", 0}
|
||||||
|
}},
|
||||||
|
{"results", data}
|
||||||
|
};
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
static bool is_valid_utf8(const std::string & str) {
|
static bool is_valid_utf8(const std::string & str) {
|
||||||
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
|
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
|
||||||
const unsigned char* end = bytes + str.length();
|
const unsigned char* end = bytes + str.length();
|
||||||
|
@ -345,6 +345,8 @@ class MODEL_TENSOR(IntEnum):
|
|||||||
ENC_FFN_DOWN = auto()
|
ENC_FFN_DOWN = auto()
|
||||||
ENC_FFN_UP = auto()
|
ENC_FFN_UP = auto()
|
||||||
ENC_OUTPUT_NORM = auto()
|
ENC_OUTPUT_NORM = auto()
|
||||||
|
CLS = auto() # classifier
|
||||||
|
CLS_OUT = auto() # classifier output projection
|
||||||
|
|
||||||
|
|
||||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||||
@ -504,6 +506,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||||||
MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
|
MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
|
||||||
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
|
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
|
||||||
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
|
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
|
||||||
|
MODEL_TENSOR.CLS: "cls",
|
||||||
|
MODEL_TENSOR.CLS_OUT: "cls.output",
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
@ -613,6 +617,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_DOWN,
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
MODEL_TENSOR.FFN_UP,
|
MODEL_TENSOR.FFN_UP,
|
||||||
MODEL_TENSOR.LAYER_OUT_NORM,
|
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||||
|
MODEL_TENSOR.CLS,
|
||||||
|
MODEL_TENSOR.CLS_OUT,
|
||||||
],
|
],
|
||||||
MODEL_ARCH.NOMIC_BERT: [
|
MODEL_ARCH.NOMIC_BERT: [
|
||||||
MODEL_TENSOR.TOKEN_EMBD,
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
@ -644,6 +650,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_GATE,
|
MODEL_TENSOR.FFN_GATE,
|
||||||
MODEL_TENSOR.FFN_DOWN,
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
MODEL_TENSOR.LAYER_OUT_NORM,
|
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||||
|
MODEL_TENSOR.CLS,
|
||||||
],
|
],
|
||||||
MODEL_ARCH.MPT: [
|
MODEL_ARCH.MPT: [
|
||||||
MODEL_TENSOR.TOKEN_EMBD,
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
@ -679,6 +679,15 @@ class TensorNameMap:
|
|||||||
MODEL_TENSOR.ENC_OUTPUT_NORM: (
|
MODEL_TENSOR.ENC_OUTPUT_NORM: (
|
||||||
"encoder.final_layer_norm", # t5
|
"encoder.final_layer_norm", # t5
|
||||||
),
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.CLS: (
|
||||||
|
"classifier", # jina
|
||||||
|
"classifier.dense", # roberta
|
||||||
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.CLS_OUT: (
|
||||||
|
"classifier.out_proj", # roberta
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# architecture-specific block mappings
|
# architecture-specific block mappings
|
||||||
|
@ -193,6 +193,7 @@ extern "C" {
|
|||||||
LLAMA_POOLING_TYPE_MEAN = 1,
|
LLAMA_POOLING_TYPE_MEAN = 1,
|
||||||
LLAMA_POOLING_TYPE_CLS = 2,
|
LLAMA_POOLING_TYPE_CLS = 2,
|
||||||
LLAMA_POOLING_TYPE_LAST = 3,
|
LLAMA_POOLING_TYPE_LAST = 3,
|
||||||
|
LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
|
||||||
};
|
};
|
||||||
|
|
||||||
enum llama_attention_type {
|
enum llama_attention_type {
|
||||||
@ -872,7 +873,8 @@ extern "C" {
|
|||||||
|
|
||||||
// Get the embeddings for a sequence id
|
// Get the embeddings for a sequence id
|
||||||
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
|
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
|
||||||
// shape: [n_embd] (1-dimensional)
|
// when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
|
||||||
|
// otherwise: float[n_embd] (1-dimensional)
|
||||||
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -1554,7 +1554,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
|
|||||||
} break;
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_UGM:
|
case LLAMA_VOCAB_TYPE_UGM:
|
||||||
{
|
{
|
||||||
if (add_special && vocab.tokenizer_add_bos != 0) {
|
if (add_special && vocab.tokenizer_add_bos) {
|
||||||
GGML_ASSERT(vocab.special_bos_id != -1);
|
GGML_ASSERT(vocab.special_bos_id != -1);
|
||||||
output.push_back(vocab.special_bos_id);
|
output.push_back(vocab.special_bos_id);
|
||||||
}
|
}
|
||||||
@ -1572,14 +1572,14 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
|
if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
|
||||||
LLAMA_LOG_WARN(
|
LLAMA_LOG_WARN(
|
||||||
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
|
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
|
||||||
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
|
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
|
||||||
"Are you sure this is what you want?\n", __FUNCTION__);
|
"Are you sure this is what you want?\n", __FUNCTION__);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (add_special && vocab.tokenizer_add_eos == 1) {
|
if (add_special && vocab.tokenizer_add_eos) {
|
||||||
GGML_ASSERT(vocab.special_eos_id != -1);
|
GGML_ASSERT(vocab.special_eos_id != -1);
|
||||||
output.push_back(vocab.special_eos_id);
|
output.push_back(vocab.special_eos_id);
|
||||||
}
|
}
|
||||||
@ -1791,11 +1791,13 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
|
|||||||
// suppressing them like CONTROL tokens.
|
// suppressing them like CONTROL tokens.
|
||||||
if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
|
if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
|
||||||
return _try_copy(token_text.data(), token_text.size());
|
return _try_copy(token_text.data(), token_text.size());
|
||||||
} else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
|
}
|
||||||
|
if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
|
||||||
std::string result = token_text;
|
std::string result = token_text;
|
||||||
llama_unescape_whitespace(result);
|
llama_unescape_whitespace(result);
|
||||||
return _try_copy(result.data(), result.size());
|
return _try_copy(result.data(), result.size());
|
||||||
} else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
|
}
|
||||||
|
if (attr & LLAMA_TOKEN_ATTR_BYTE) {
|
||||||
char byte = (char) llama_token_to_byte(vocab, token);
|
char byte = (char) llama_token_to_byte(vocab, token);
|
||||||
return _try_copy((char*) &byte, 1);
|
return _try_copy((char*) &byte, 1);
|
||||||
}
|
}
|
||||||
@ -1806,7 +1808,8 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
|
|||||||
// suppressing them like CONTROL tokens.
|
// suppressing them like CONTROL tokens.
|
||||||
if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
|
if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
|
||||||
return _try_copy(token_text.data(), token_text.size());
|
return _try_copy(token_text.data(), token_text.size());
|
||||||
} else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
|
}
|
||||||
|
if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
|
||||||
std::string result = llama_decode_text(token_text);
|
std::string result = llama_decode_text(token_text);
|
||||||
return _try_copy(result.data(), result.size());
|
return _try_copy(result.data(), result.size());
|
||||||
}
|
}
|
||||||
|
@ -606,6 +606,8 @@ enum llm_tensor {
|
|||||||
LLM_TENSOR_ENC_FFN_DOWN,
|
LLM_TENSOR_ENC_FFN_DOWN,
|
||||||
LLM_TENSOR_ENC_FFN_UP,
|
LLM_TENSOR_ENC_FFN_UP,
|
||||||
LLM_TENSOR_ENC_OUTPUT_NORM,
|
LLM_TENSOR_ENC_OUTPUT_NORM,
|
||||||
|
LLM_TENSOR_CLS,
|
||||||
|
LLM_TENSOR_CLS_OUT,
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
|
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
|
||||||
@ -793,6 +795,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
|||||||
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
|
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
|
||||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
{ LLM_TENSOR_CLS, "cls" },
|
||||||
|
{ LLM_TENSOR_CLS_OUT, "cls.output" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -828,6 +832,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
|||||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
{ LLM_TENSOR_CLS, "cls" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -2894,6 +2899,7 @@ struct llama_model {
|
|||||||
llama_hparams hparams = {};
|
llama_hparams hparams = {};
|
||||||
llama_vocab vocab;
|
llama_vocab vocab;
|
||||||
|
|
||||||
|
// TODO: should init all tensors to nullptr
|
||||||
struct ggml_tensor * tok_embd;
|
struct ggml_tensor * tok_embd;
|
||||||
struct ggml_tensor * type_embd;
|
struct ggml_tensor * type_embd;
|
||||||
struct ggml_tensor * pos_embd;
|
struct ggml_tensor * pos_embd;
|
||||||
@ -2906,6 +2912,12 @@ struct llama_model {
|
|||||||
struct ggml_tensor * output_b;
|
struct ggml_tensor * output_b;
|
||||||
struct ggml_tensor * output_norm_enc;
|
struct ggml_tensor * output_norm_enc;
|
||||||
|
|
||||||
|
// classifier
|
||||||
|
struct ggml_tensor * cls;
|
||||||
|
struct ggml_tensor * cls_b;
|
||||||
|
struct ggml_tensor * cls_out = nullptr;
|
||||||
|
struct ggml_tensor * cls_out_b = nullptr;
|
||||||
|
|
||||||
std::vector<llama_layer> layers;
|
std::vector<llama_layer> layers;
|
||||||
|
|
||||||
llama_split_mode split_mode;
|
llama_split_mode split_mode;
|
||||||
@ -5604,7 +5616,7 @@ static void llm_load_hparams(
|
|||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||||
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||||
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
|
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
|
||||||
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
|
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
|
||||||
hparams.f_max_alibi_bias = 8.0f;
|
hparams.f_max_alibi_bias = 8.0f;
|
||||||
|
|
||||||
switch (hparams.n_layer) {
|
switch (hparams.n_layer) {
|
||||||
@ -6313,6 +6325,7 @@ static void llm_load_vocab(
|
|||||||
tokenizer_pre == "phi-2" ||
|
tokenizer_pre == "phi-2" ||
|
||||||
tokenizer_pre == "jina-es" ||
|
tokenizer_pre == "jina-es" ||
|
||||||
tokenizer_pre == "jina-de" ||
|
tokenizer_pre == "jina-de" ||
|
||||||
|
tokenizer_pre == "jina-v1-en" ||
|
||||||
tokenizer_pre == "jina-v2-es" ||
|
tokenizer_pre == "jina-v2-es" ||
|
||||||
tokenizer_pre == "jina-v2-de" ||
|
tokenizer_pre == "jina-v2-de" ||
|
||||||
tokenizer_pre == "jina-v2-code") {
|
tokenizer_pre == "jina-v2-code") {
|
||||||
@ -6439,7 +6452,12 @@ static void llm_load_vocab(
|
|||||||
|
|
||||||
for (uint32_t i = 0; i < n_vocab; i++) {
|
for (uint32_t i = 0; i < n_vocab; i++) {
|
||||||
std::string word = gguf_get_arr_str(ctx, token_idx, i);
|
std::string word = gguf_get_arr_str(ctx, token_idx, i);
|
||||||
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
|
|
||||||
|
//GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
|
||||||
|
if (word.empty()) {
|
||||||
|
LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i);
|
||||||
|
word = "[EMPTY_" + std::to_string(i) + "]";
|
||||||
|
}
|
||||||
|
|
||||||
vocab.token_to_id[word] = i;
|
vocab.token_to_id[word] = i;
|
||||||
vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
|
vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
|
||||||
@ -6520,9 +6538,15 @@ static void llm_load_vocab(
|
|||||||
vocab.linefeed_id = ids[0];
|
vocab.linefeed_id = ids[0];
|
||||||
} else {
|
} else {
|
||||||
const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
|
const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
|
||||||
GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
|
|
||||||
|
//GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
|
||||||
|
if (ids.empty()) {
|
||||||
|
LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__);
|
||||||
|
vocab.linefeed_id = vocab.special_pad_id;
|
||||||
|
} else {
|
||||||
vocab.linefeed_id = ids[0];
|
vocab.linefeed_id = ids[0];
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// special tokens
|
// special tokens
|
||||||
{
|
{
|
||||||
@ -7394,6 +7418,12 @@ static bool llm_load_tensors(
|
|||||||
|
|
||||||
if (model.arch == LLM_ARCH_BERT) {
|
if (model.arch == LLM_ARCH_BERT) {
|
||||||
model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train});
|
model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train});
|
||||||
|
|
||||||
|
model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
|
||||||
|
model.cls_out = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
model.cls_out_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
}
|
}
|
||||||
|
|
||||||
model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
|
model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
|
||||||
@ -7446,6 +7476,8 @@ static bool llm_load_tensors(
|
|||||||
model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm
|
model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm
|
||||||
model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias
|
model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias
|
||||||
|
|
||||||
|
model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
for (int i = 0; i < n_layer; ++i) {
|
for (int i = 0; i < n_layer; ++i) {
|
||||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||||
@ -10279,6 +10311,10 @@ struct llm_build_context {
|
|||||||
struct ggml_tensor * cur;
|
struct ggml_tensor * cur;
|
||||||
|
|
||||||
switch (pooling_type) {
|
switch (pooling_type) {
|
||||||
|
case LLAMA_POOLING_TYPE_NONE:
|
||||||
|
{
|
||||||
|
cur = inp;
|
||||||
|
} break;
|
||||||
case LLAMA_POOLING_TYPE_MEAN:
|
case LLAMA_POOLING_TYPE_MEAN:
|
||||||
{
|
{
|
||||||
struct ggml_tensor * inp_mean = build_inp_mean();
|
struct ggml_tensor * inp_mean = build_inp_mean();
|
||||||
@ -10290,9 +10326,26 @@ struct llm_build_context {
|
|||||||
struct ggml_tensor * inp_cls = build_inp_cls();
|
struct ggml_tensor * inp_cls = build_inp_cls();
|
||||||
cur = ggml_get_rows(ctx0, inp, inp_cls);
|
cur = ggml_get_rows(ctx0, inp, inp_cls);
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_POOLING_TYPE_NONE:
|
case LLAMA_POOLING_TYPE_RANK:
|
||||||
{
|
{
|
||||||
cur = inp;
|
struct ggml_tensor * inp_cls = build_inp_cls();
|
||||||
|
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
||||||
|
|
||||||
|
// classification head
|
||||||
|
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
||||||
|
GGML_ASSERT(model.cls != nullptr);
|
||||||
|
GGML_ASSERT(model.cls_b != nullptr);
|
||||||
|
|
||||||
|
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
|
||||||
|
cur = ggml_tanh(ctx0, cur);
|
||||||
|
|
||||||
|
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||||
|
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
||||||
|
if (model.cls_out) {
|
||||||
|
GGML_ASSERT(model.cls_out_b != nullptr);
|
||||||
|
|
||||||
|
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
@ -11521,8 +11574,8 @@ struct llm_build_context {
|
|||||||
inpL = cur;
|
inpL = cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
// final output
|
|
||||||
cur = inpL;
|
cur = inpL;
|
||||||
|
|
||||||
cb(cur, "result_embd", -1);
|
cb(cur, "result_embd", -1);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, cur);
|
ggml_build_forward_expand(gf, cur);
|
||||||
@ -16682,7 +16735,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
if (cparams.embeddings && (
|
||||||
|
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
|
||||||
|
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
|
||||||
const int64_t n_tokens = batch.n_tokens;
|
const int64_t n_tokens = batch.n_tokens;
|
||||||
const int64_t n_seq_tokens = batch.n_seq_tokens;
|
const int64_t n_seq_tokens = batch.n_seq_tokens;
|
||||||
const int64_t n_seqs = batch.n_seqs;
|
const int64_t n_seqs = batch.n_seqs;
|
||||||
@ -16697,7 +16752,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
|
|||||||
const llama_seq_id seq_id = batch.seq_id[s][0];
|
const llama_seq_id seq_id = batch.seq_id[s][0];
|
||||||
|
|
||||||
// TODO: adapt limits to n_seqs when batch.equal_seqs is true
|
// TODO: adapt limits to n_seqs when batch.equal_seqs is true
|
||||||
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
|
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
|
||||||
|
|
||||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||||
const llama_pos pos = batch.pos[s*n_seq_tokens + i];
|
const llama_pos pos = batch.pos[s*n_seq_tokens + i];
|
||||||
@ -17237,6 +17292,20 @@ static int llama_decode_internal(
|
|||||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLAMA_POOLING_TYPE_RANK:
|
||||||
|
{
|
||||||
|
// extract the rerank score - a single float per sequence
|
||||||
|
auto & embd_seq_out = lctx.embd_seq;
|
||||||
|
|
||||||
|
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||||
|
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||||
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
embd_seq_out[seq_id].resize(1);
|
||||||
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||||
{
|
{
|
||||||
GGML_ABORT("unknown pooling type");
|
GGML_ABORT("unknown pooling type");
|
||||||
@ -17443,6 +17512,13 @@ static int llama_encode_internal(
|
|||||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLAMA_POOLING_TYPE_RANK:
|
||||||
|
{
|
||||||
|
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
|
||||||
|
// wait for an encoder model that requires this pooling type in order to test it
|
||||||
|
// https://github.com/ggerganov/llama.cpp/pull/9510
|
||||||
|
GGML_ABORT("RANK pooling not implemented yet");
|
||||||
|
}
|
||||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||||
{
|
{
|
||||||
GGML_ABORT("unknown pooling type");
|
GGML_ABORT("unknown pooling type");
|
||||||
|
Loading…
Reference in New Issue
Block a user