mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 06:19:02 +01:00
llama : add pooling switch
This commit is contained in:
parent
9bbeb0f110
commit
e66da356a4
43
llama.cpp
43
llama.cpp
@ -8113,7 +8113,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||
const llama_pos pos = batch.pos[i];
|
||||
const llama_pos pos = batch.pos[i];
|
||||
if (pos == 0) {
|
||||
data[seq_id] = i;
|
||||
}
|
||||
@ -8379,10 +8379,17 @@ static int llama_decode_internal(
|
||||
if (batch.logits[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
|
||||
} else {
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
|
||||
switch (hparams.pooling_type) {
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
|
||||
break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown pooling type");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -8680,19 +8687,19 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
|
||||
GGML_ASSERT(llama_is_byte_token(vocab, id));
|
||||
const auto& token_data = vocab.id_to_token.at(id);
|
||||
switch (llama_vocab_get_type(vocab)) {
|
||||
case LLAMA_VOCAB_TYPE_SPM: {
|
||||
auto buf = token_data.text.substr(3, 2);
|
||||
return strtol(buf.c_str(), NULL, 16);
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_BPE: {
|
||||
GGML_ASSERT(false);
|
||||
return unicode_to_bytes_bpe(token_data.text);
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_WPM: {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
case LLAMA_VOCAB_TYPE_SPM: {
|
||||
auto buf = token_data.text.substr(3, 2);
|
||||
return strtol(buf.c_str(), NULL, 16);
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_BPE: {
|
||||
GGML_ASSERT(false);
|
||||
return unicode_to_bytes_bpe(token_data.text);
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_WPM: {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user