mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 05:17:21 +01:00
Cleanup STL headers + fix embedding examples + minor stuff
This commit is contained in:
parent
55ad42af84
commit
03f7e33560
@ -1,15 +1,6 @@
|
|||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <cinttypes>
|
|
||||||
#include <cmath>
|
|
||||||
#include <cstdio>
|
|
||||||
#include <cstring>
|
|
||||||
#include <fstream>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
params.model = "models/llama-7B/ggml-model.bin";
|
params.model = "models/llama-7B/ggml-model.bin";
|
||||||
@ -94,9 +85,13 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const int n_embd = llama_n_embd(ctx);
|
||||||
const auto embeddings = llama_get_embeddings(ctx);
|
const auto embeddings = llama_get_embeddings(ctx);
|
||||||
|
|
||||||
// TODO: print / use the embeddings
|
for (int i = 0; i < n_embd; i++) {
|
||||||
|
printf("%f ", embeddings[i]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_print_timings(ctx);
|
llama_print_timings(ctx);
|
||||||
|
@ -1,14 +1,6 @@
|
|||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <cinttypes>
|
|
||||||
#include <cmath>
|
|
||||||
#include <cstdio>
|
|
||||||
#include <cstring>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
std::vector<double> softmax(const std::vector<float>& logits) {
|
std::vector<double> softmax(const std::vector<float>& logits) {
|
||||||
std::vector<double> probs(logits.size());
|
std::vector<double> probs(logits.size());
|
||||||
float max_logit = logits[0];
|
float max_logit = logits[0];
|
||||||
|
22
llama.cpp
22
llama.cpp
@ -1261,10 +1261,10 @@ static llama_vocab::id llama_sample_top_p_top_k(
|
|||||||
double repeat_penalty) {
|
double repeat_penalty) {
|
||||||
auto & rng = lctx.rng;
|
auto & rng = lctx.rng;
|
||||||
|
|
||||||
const auto & vocab = lctx.vocab;
|
const int n_logits = lctx.model.hparams.n_vocab;
|
||||||
const auto & logits = lctx.logits;
|
|
||||||
|
|
||||||
int n_logits = vocab.id_to_token.size();
|
const auto & logits = lctx.logits;
|
||||||
|
const auto * plogits = logits.data() + logits.size() - n_logits;
|
||||||
|
|
||||||
std::vector<std::pair<double, llama_vocab::id>> logits_id;
|
std::vector<std::pair<double, llama_vocab::id>> logits_id;
|
||||||
logits_id.reserve(n_logits);
|
logits_id.reserve(n_logits);
|
||||||
@ -1276,13 +1276,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
|
|||||||
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
|
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
|
||||||
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
|
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
|
||||||
// if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
|
// if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
|
||||||
if (logits[i] < 0.0) {
|
if (plogits[i] < 0.0) {
|
||||||
logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
|
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
|
||||||
} else {
|
} else {
|
||||||
logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
|
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logits_id.push_back(std::make_pair(logits[i]*scale, i));
|
logits_id.push_back(std::make_pair(plogits[i]*scale, i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1677,6 +1677,8 @@ struct llama_context * llama_init_from_file(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
|
// resized during inference
|
||||||
if (params.logits_all) {
|
if (params.logits_all) {
|
||||||
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
|
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
|
||||||
} else {
|
} else {
|
||||||
@ -1684,7 +1686,7 @@ struct llama_context * llama_init_from_file(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (params.embedding){
|
if (params.embedding){
|
||||||
ctx->embedding.reserve(hparams.n_embd);
|
ctx->embedding.resize(hparams.n_embd);
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
|
ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
|
||||||
@ -1761,6 +1763,10 @@ int llama_n_ctx(struct llama_context * ctx) {
|
|||||||
return ctx->model.hparams.n_ctx;
|
return ctx->model.hparams.n_ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int llama_n_embd(struct llama_context * ctx) {
|
||||||
|
return ctx->model.hparams.n_embd;
|
||||||
|
}
|
||||||
|
|
||||||
float * llama_get_logits(struct llama_context * ctx) {
|
float * llama_get_logits(struct llama_context * ctx) {
|
||||||
return ctx->logits.data();
|
return ctx->logits.data();
|
||||||
}
|
}
|
||||||
|
1
llama.h
1
llama.h
@ -109,6 +109,7 @@ extern "C" {
|
|||||||
|
|
||||||
LLAMA_API int llama_n_vocab(struct llama_context * ctx);
|
LLAMA_API int llama_n_vocab(struct llama_context * ctx);
|
||||||
LLAMA_API int llama_n_ctx (struct llama_context * ctx);
|
LLAMA_API int llama_n_ctx (struct llama_context * ctx);
|
||||||
|
LLAMA_API int llama_n_embd (struct llama_context * ctx);
|
||||||
|
|
||||||
// Token logits obtained from the last call to llama_eval()
|
// Token logits obtained from the last call to llama_eval()
|
||||||
// The logits for the last token are stored in the last row
|
// The logits for the last token are stored in the last row
|
||||||
|
Loading…
x
Reference in New Issue
Block a user