2023-04-11 06:19:54 -07:00
|
|
|
// Defines sigaction on msys:
|
|
|
|
#ifndef _GNU_SOURCE
|
|
|
|
#define _GNU_SOURCE
|
|
|
|
#endif
|
|
|
|
|
2023-03-25 20:26:40 +02:00
|
|
|
#include "common.h"
|
2023-03-22 07:32:36 +02:00
|
|
|
#include "llama.h"
|
2023-03-10 20:40:58 +02:00
|
|
|
|
|
|
|
#include <cassert>
|
2023-03-20 03:17:23 -07:00
|
|
|
#include <cinttypes>
|
2023-03-10 20:40:58 +02:00
|
|
|
#include <cmath>
|
|
|
|
#include <cstdio>
|
|
|
|
#include <cstring>
|
2023-04-16 12:13:00 +02:00
|
|
|
#include <ctime>
|
2023-03-10 20:40:58 +02:00
|
|
|
#include <fstream>
|
2023-03-19 13:44:30 -06:00
|
|
|
#include <iostream>
|
2023-03-10 20:40:58 +02:00
|
|
|
#include <string>
|
|
|
|
#include <vector>
|
|
|
|
|
2023-03-13 04:08:01 +01:00
|
|
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
2023-03-12 22:13:28 +01:00
|
|
|
#include <signal.h>
|
|
|
|
#include <unistd.h>
|
2023-03-15 13:56:24 -06:00
|
|
|
#elif defined (_WIN32)
|
|
|
|
#include <signal.h>
|
2023-03-13 04:08:01 +01:00
|
|
|
#endif
|
2023-03-12 22:13:28 +01:00
|
|
|
|
2023-03-28 17:09:55 +03:00
|
|
|
// Console state shared by main() and sigint_handler(): main() stores the
// use_color choice in it, and it is passed to every set_console_color() call.
static console_state con_st;
|
2023-04-22 16:56:35 +08:00
|
|
|
// Address of main()'s local `ctx` pointer. main() assigns it before loading the
// model; sigint_handler() dereferences it to print timings before exiting.
static llama_context ** g_ctx;
|
2023-03-21 18:11:01 +01:00
|
|
|
|
2023-03-12 22:13:28 +01:00
|
|
|
// True while control should go to the user instead of sampling more tokens.
// Set by sigint_handler() on the first Ctrl+C, and by main() from
// params.interactive_first or when a reverse prompt is matched.
// NOTE(review): this is a plain bool written from a signal handler and read in
// main(); consider `volatile sig_atomic_t` (or an atomic) - confirm intent.
static bool is_interacting = false;
|
|
|
|
|
2023-03-15 13:56:24 -06:00
|
|
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
2023-03-12 22:13:28 +01:00
|
|
|
void sigint_handler(int signo) {
|
2023-03-28 17:09:55 +03:00
|
|
|
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
|
2023-04-14 21:58:43 +02:00
|
|
|
printf("\n"); // this also force flush stdout.
|
2023-03-12 22:13:28 +01:00
|
|
|
if (signo == SIGINT) {
|
|
|
|
if (!is_interacting) {
|
|
|
|
is_interacting=true;
|
|
|
|
} else {
|
2023-04-22 16:56:35 +08:00
|
|
|
llama_print_timings(*g_ctx);
|
2023-03-12 22:13:28 +01:00
|
|
|
_exit(130);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2023-03-13 04:08:01 +01:00
|
|
|
#endif
|
2023-03-12 22:13:28 +01:00
|
|
|
|
2023-03-10 20:40:58 +02:00
|
|
|
int main(int argc, char ** argv) {
|
|
|
|
gpt_params params;
|
|
|
|
params.model = "models/llama-7B/ggml-model.bin";
|
|
|
|
|
|
|
|
if (gpt_params_parse(argc, argv, params) == false) {
|
|
|
|
return 1;
|
|
|
|
}
|
2023-03-19 18:37:02 +02:00
|
|
|
|
2023-03-25 22:29:22 +02:00
|
|
|
// save choice to use color for later
|
|
|
|
// (note for later: this is a slightly awkward choice)
|
2023-03-28 17:09:55 +03:00
|
|
|
con_st.use_color = params.use_color;
|
2023-03-25 22:29:22 +02:00
|
|
|
|
|
|
|
#if defined (_WIN32)
|
2023-03-28 17:09:55 +03:00
|
|
|
win32_console_init(params.use_color);
|
2023-03-25 22:29:22 +02:00
|
|
|
#endif
|
|
|
|
|
2023-03-25 20:26:40 +02:00
|
|
|
if (params.perplexity) {
|
|
|
|
printf("\n************\n");
|
|
|
|
printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
|
|
|
|
printf("************\n\n");
|
|
|
|
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
2023-03-25 21:36:22 +02:00
|
|
|
if (params.embedding) {
|
|
|
|
printf("\n************\n");
|
|
|
|
printf("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
|
|
|
|
printf("************\n\n");
|
|
|
|
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
2023-03-18 17:10:47 -07:00
|
|
|
if (params.n_ctx > 2048) {
|
|
|
|
fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
|
|
|
|
"expect poor results\n", __func__, params.n_ctx);
|
|
|
|
}
|
2023-03-10 20:40:58 +02:00
|
|
|
|
2023-03-22 07:47:15 +02:00
|
|
|
if (params.seed <= 0) {
|
2023-03-10 20:40:58 +02:00
|
|
|
params.seed = time(NULL);
|
|
|
|
}
|
|
|
|
|
2023-03-13 17:39:56 +01:00
|
|
|
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
|
2023-03-10 20:40:58 +02:00
|
|
|
|
|
|
|
std::mt19937 rng(params.seed);
|
2023-03-19 19:36:19 +01:00
|
|
|
if (params.random_prompt) {
|
2023-03-10 20:40:58 +02:00
|
|
|
params.prompt = gpt_random_prompt(rng);
|
|
|
|
}
|
|
|
|
|
2023-03-10 23:46:39 +02:00
|
|
|
// params.prompt = R"(// this function checks if the number n is prime
|
|
|
|
//bool is_prime(int n) {)";
|
2023-04-22 10:54:13 +00:00
|
|
|
|
2023-03-22 07:32:36 +02:00
|
|
|
llama_context * ctx;
|
2023-04-22 16:56:35 +08:00
|
|
|
g_ctx = &ctx;
|
2023-03-10 20:40:58 +02:00
|
|
|
|
|
|
|
// load the model
|
|
|
|
{
|
2023-03-22 07:32:36 +02:00
|
|
|
auto lparams = llama_context_default_params();
|
|
|
|
|
2023-03-22 07:45:00 +02:00
|
|
|
lparams.n_ctx = params.n_ctx;
|
|
|
|
lparams.n_parts = params.n_parts;
|
|
|
|
lparams.seed = params.seed;
|
|
|
|
lparams.f16_kv = params.memory_f16;
|
Rewrite loading code to try to satisfy everyone:
- Support all three formats (ggml, ggmf, ggjt). (However, I didn't
include the hack needed to support GPT4All files without conversion.
Those can still be used after converting them with convert.py from my
other PR.)
- Support both mmap and read (mmap is used by default, but can be
disabled with `--no-mmap`, and is automatically disabled for pre-ggjt
files or on platforms where mmap is not supported).
- Support multi-file models like before, but automatically determine the
number of parts rather than requiring `--n_parts`.
- Improve validation and error checking.
- Stop using the per-file type field (f16) entirely in favor of just
relying on the per-tensor type/size fields. This has no immediate
benefit, but makes it easier to experiment with different formats, and
should make it easier to support the new GPTQ-for-LLaMa models in the
future (I have some work in progress on that front).
- Support VirtualLock on Windows (using the same `--mlock` option as on
Unix).
- Indicate loading progress when using mmap + mlock. (Which led me
to the interesting observation that on my Linux machine, with a
warm file cache, mlock actually takes some time, whereas mmap
without mlock starts almost instantly...)
- To help implement this, move mlock support from ggml to the
loading code.
- madvise/PrefetchVirtualMemory support (based on #740)
- Switch from ifstream to the `fopen` family of functions to avoid
unnecessary copying and, when mmap is enabled, allow reusing the same
file descriptor for both metadata reads and mmap (whereas the existing
implementation opens the file a second time to mmap).
- Quantization now produces a single-file output even with multi-file
inputs (not really a feature as much as 'it was easier this way').
Implementation notes:
I tried to factor the code into more discrete pieces than before.
Regarding code style: I tried to follow the code style, but I'm naughty
and used a few advanced C++ features repeatedly:
- Destructors to make it easier to ensure everything gets cleaned up.
- Exceptions. I don't even usually use exceptions when writing C++, and
I can remove them if desired... but here they make the loading code
much more succinct while still properly handling a variety of errors,
ranging from API calls failing to integer overflow and allocation
failure. The exceptions are converted to error codes at the
API boundary.
Co-authored-by: Pavol Rusnak <pavol@rusnak.io> (for the bit I copied from #740)
2023-04-08 12:24:37 -07:00
|
|
|
lparams.use_mmap = params.use_mmap;
|
2023-03-24 08:19:05 -07:00
|
|
|
lparams.use_mlock = params.use_mlock;
|
2023-03-22 07:32:36 +02:00
|
|
|
|
|
|
|
ctx = llama_init_from_file(params.model.c_str(), lparams);
|
|
|
|
|
|
|
|
if (ctx == NULL) {
|
|
|
|
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
|
2023-03-10 20:40:58 +02:00
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-04-17 17:28:55 +02:00
|
|
|
if (!params.lora_adapter.empty()) {
|
|
|
|
int err = llama_apply_lora_from_file(ctx,
|
|
|
|
params.lora_adapter.c_str(),
|
|
|
|
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
|
|
|
|
params.n_threads);
|
|
|
|
if (err != 0) {
|
|
|
|
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-13 19:15:08 +02:00
|
|
|
// print system information
|
|
|
|
{
|
|
|
|
fprintf(stderr, "\n");
|
|
|
|
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
|
|
|
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
|
|
|
}
|
|
|
|
|
2023-03-24 23:17:37 +02:00
|
|
|
// determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
|
|
|
|
// uncomment the "used_mem" line in llama.cpp to see the results
|
|
|
|
if (params.mem_test) {
|
|
|
|
{
|
|
|
|
const std::vector<llama_token> tmp(params.n_batch, 0);
|
|
|
|
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
|
|
|
|
}
|
|
|
|
|
|
|
|
{
|
|
|
|
const std::vector<llama_token> tmp = { 0, };
|
|
|
|
llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
|
|
|
|
}
|
|
|
|
|
|
|
|
llama_print_timings(ctx);
|
|
|
|
llama_free(ctx);
|
|
|
|
|
|
|
|
return 0;
|
2023-03-22 07:32:36 +02:00
|
|
|
}
|
2023-03-21 09:27:42 -07:00
|
|
|
|
2023-03-17 21:05:58 +01:00
|
|
|
// Add a space in front of the first character to match OG llama tokenizer behavior
|
|
|
|
params.prompt.insert(0, 1, ' ');
|
2023-03-22 07:32:36 +02:00
|
|
|
|
2023-04-28 11:59:37 -04:00
|
|
|
std::string path_session = params.path_session;
|
|
|
|
std::vector<llama_token> session_tokens;
|
|
|
|
|
|
|
|
if (!path_session.empty()) {
|
|
|
|
fprintf(stderr, "%s: attempting to load saved session from %s..\n", __func__, path_session.c_str());
|
|
|
|
|
|
|
|
// REVIEW - fopen to check for existing session
|
|
|
|
FILE * fp = std::fopen(path_session.c_str(), "rb");
|
|
|
|
if (fp != NULL) {
|
|
|
|
std::fclose(fp);
|
|
|
|
|
|
|
|
session_tokens.resize(params.n_ctx);
|
|
|
|
size_t n_token_count_out = 0;
|
|
|
|
const size_t n_session_bytes = llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out);
|
|
|
|
session_tokens.resize(n_token_count_out);
|
|
|
|
|
|
|
|
if (n_session_bytes > 0) {
|
|
|
|
fprintf(stderr, "%s: loaded %zu bytes of session data!\n", __func__, n_session_bytes);
|
|
|
|
} else {
|
|
|
|
fprintf(stderr, "%s: could not load session file, will recreate\n", __func__);
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-10 20:40:58 +02:00
|
|
|
// tokenize the prompt
|
2023-03-22 07:32:36 +02:00
|
|
|
auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
|
2023-03-10 20:40:58 +02:00
|
|
|
|
2023-03-22 07:32:36 +02:00
|
|
|
const int n_ctx = llama_n_ctx(ctx);
|
|
|
|
|
2023-03-25 21:36:22 +02:00
|
|
|
if ((int) embd_inp.size() > n_ctx - 4) {
|
|
|
|
fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
|
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
|
2023-04-28 11:59:37 -04:00
|
|
|
// debug message about similarity of saved session, if applicable
|
|
|
|
size_t n_matching_session_tokens = 0;
|
|
|
|
if (session_tokens.size()) {
|
|
|
|
for (llama_token id : session_tokens) {
|
|
|
|
if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
n_matching_session_tokens++;
|
|
|
|
}
|
|
|
|
if (n_matching_session_tokens >= embd_inp.size()) {
|
|
|
|
fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
|
|
|
|
} else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
|
|
|
|
fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
|
|
|
|
__func__, n_matching_session_tokens, embd_inp.size());
|
|
|
|
} else {
|
|
|
|
fprintf(stderr, "%s: session file matches %zu / %zu tokens of prompt\n",
|
|
|
|
__func__, n_matching_session_tokens, embd_inp.size());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-28 17:09:55 +03:00
|
|
|
// number of tokens to keep when resetting context
|
2023-04-14 21:58:43 +02:00
|
|
|
if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
|
2023-03-28 17:09:55 +03:00
|
|
|
params.n_keep = (int)embd_inp.size();
|
|
|
|
}
|
2023-03-10 20:40:58 +02:00
|
|
|
|
2023-03-19 18:37:02 +02:00
|
|
|
// prefix & suffix for instruct mode
|
2023-04-14 21:58:43 +02:00
|
|
|
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
|
|
|
|
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
|
|
|
|
|
|
|
|
// in instruct mode, we inject a prefix and a suffix to each input by the user
|
|
|
|
if (params.instruct) {
|
2023-04-24 17:45:32 +02:00
|
|
|
params.interactive_first = true;
|
2023-04-14 21:58:43 +02:00
|
|
|
params.antiprompt.push_back("### Instruction:\n\n");
|
2023-03-19 18:37:02 +02:00
|
|
|
}
|
|
|
|
|
2023-03-28 17:09:55 +03:00
|
|
|
// enable interactive mode if reverse prompt or interactive start is specified
|
2023-04-24 17:45:32 +02:00
|
|
|
if (params.antiprompt.size() != 0 || params.interactive_first) {
|
2023-03-22 18:16:35 +01:00
|
|
|
params.interactive = true;
|
|
|
|
}
|
|
|
|
|
2023-03-23 15:22:47 -05:00
|
|
|
// determine newline token
|
|
|
|
auto llama_token_newline = ::llama_tokenize(ctx, "\n", false);
|
|
|
|
|
2023-03-25 17:16:50 +02:00
|
|
|
if (params.verbose_prompt) {
|
|
|
|
fprintf(stderr, "\n");
|
|
|
|
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
|
|
|
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
|
|
|
|
for (int i = 0; i < (int) embd_inp.size(); i++) {
|
|
|
|
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
|
|
|
|
}
|
2023-03-25 21:36:22 +02:00
|
|
|
if (params.n_keep > 0) {
|
|
|
|
fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
|
|
|
|
for (int i = 0; i < params.n_keep; i++) {
|
|
|
|
fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]));
|
|
|
|
}
|
|
|
|
fprintf(stderr, "'\n");
|
|
|
|
}
|
2023-03-25 17:16:50 +02:00
|
|
|
fprintf(stderr, "\n");
|
2023-03-10 20:40:58 +02:00
|
|
|
}
|
2023-03-25 17:16:50 +02:00
|
|
|
|
2023-03-12 22:13:28 +01:00
|
|
|
if (params.interactive) {
|
2023-03-13 04:08:01 +01:00
|
|
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
2023-03-12 22:13:28 +01:00
|
|
|
struct sigaction sigint_action;
|
|
|
|
sigint_action.sa_handler = sigint_handler;
|
|
|
|
sigemptyset (&sigint_action.sa_mask);
|
2023-03-13 19:15:08 +02:00
|
|
|
sigint_action.sa_flags = 0;
|
2023-03-12 22:13:28 +01:00
|
|
|
sigaction(SIGINT, &sigint_action, NULL);
|
2023-03-15 13:56:24 -06:00
|
|
|
#elif defined (_WIN32)
|
|
|
|
signal(SIGINT, sigint_handler);
|
2023-03-13 04:08:01 +01:00
|
|
|
#endif
|
2023-03-12 22:13:28 +01:00
|
|
|
|
2023-03-13 17:39:56 +01:00
|
|
|
fprintf(stderr, "%s: interactive mode on.\n", __func__);
|
2023-03-12 22:13:28 +01:00
|
|
|
|
2023-03-25 21:36:22 +02:00
|
|
|
if (params.antiprompt.size()) {
|
2023-03-21 17:04:43 +01:00
|
|
|
for (auto antiprompt : params.antiprompt) {
|
|
|
|
fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
|
2023-03-12 22:13:28 +01:00
|
|
|
}
|
|
|
|
}
|
2023-03-25 14:03:19 +02:00
|
|
|
|
|
|
|
if (!params.input_prefix.empty()) {
|
|
|
|
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
|
|
|
|
}
|
2023-03-12 22:13:28 +01:00
|
|
|
}
|
llama : new sampling algorithms (#1126)
* Sample interface, new samplers.
New samplers:
- locally typical sampling
- tail free sampling
- frequency and presence penalty
- mirostat
Ignore EOS fix: -inf should be used.
* mirostat
* Added --logit-bias and --no-penalize-nl, removed std::span
* Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
* Save and load example adjust
* Tests
* Windows build fix
* Windows test fix
2023-04-29 08:34:41 +03:00
|
|
|
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
|
|
|
|
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
|
2023-03-25 21:36:22 +02:00
|
|
|
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
2023-03-13 17:39:56 +01:00
|
|
|
fprintf(stderr, "\n\n");
|
2023-03-10 20:40:58 +02:00
|
|
|
|
2023-03-25 21:36:22 +02:00
|
|
|
// TODO: replace with ring-buffer
|
|
|
|
std::vector<llama_token> last_n_tokens(n_ctx);
|
2023-03-12 05:27:42 -04:00
|
|
|
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
|
|
|
|
2023-03-12 22:13:28 +01:00
|
|
|
if (params.interactive) {
|
2023-03-13 17:39:56 +01:00
|
|
|
fprintf(stderr, "== Running in interactive mode. ==\n"
|
2023-03-15 13:56:24 -06:00
|
|
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
2023-03-12 22:13:28 +01:00
|
|
|
" - Press Ctrl+C to interject at any time.\n"
|
2023-03-13 04:08:01 +01:00
|
|
|
#endif
|
2023-04-14 21:58:43 +02:00
|
|
|
" - Press Return to return control to LLaMa.\n"
|
|
|
|
" - If you want to submit another line, end your input in '\\'.\n\n");
|
2023-04-24 17:45:32 +02:00
|
|
|
is_interacting = params.interactive_first;
|
2023-03-12 22:13:28 +01:00
|
|
|
}
|
|
|
|
|
2023-04-14 21:58:43 +02:00
|
|
|
bool is_antiprompt = false;
|
2023-03-28 17:09:55 +03:00
|
|
|
bool input_noecho = false;
|
2023-03-12 22:13:28 +01:00
|
|
|
|
2023-04-28 11:59:37 -04:00
|
|
|
// HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
|
|
|
|
// if we loaded a session with at least 75% similarity. It's currently just used to speed up the
|
|
|
|
// initial prompt so it doesn't need to be an exact match.
|
|
|
|
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
|
|
|
|
|
|
|
|
|
2023-03-25 21:36:22 +02:00
|
|
|
int n_past = 0;
|
|
|
|
int n_remain = params.n_predict;
|
|
|
|
int n_consumed = 0;
|
2023-04-28 11:59:37 -04:00
|
|
|
int n_session_consumed = 0;
|
2023-03-12 22:13:28 +01:00
|
|
|
|
2023-03-21 18:11:01 +01:00
|
|
|
// the first thing we will do is to output the prompt, so set color accordingly
|
2023-03-28 17:09:55 +03:00
|
|
|
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
|
2023-03-12 22:13:28 +01:00
|
|
|
|
2023-03-25 21:36:22 +02:00
|
|
|
std::vector<llama_token> embd;
|
|
|
|
|
2023-03-25 21:51:41 +02:00
|
|
|
while (n_remain != 0 || params.interactive) {
|
2023-03-10 20:40:58 +02:00
|
|
|
// predict
|
|
|
|
if (embd.size() > 0) {
|
2023-03-25 21:36:22 +02:00
|
|
|
// infinite text generation via context swapping
|
|
|
|
// if we run out of context:
|
|
|
|
// - take the n_keep first tokens from the original prompt (via n_past)
|
2023-04-21 11:18:09 -07:00
|
|
|
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
|
2023-03-25 21:36:22 +02:00
|
|
|
if (n_past + (int) embd.size() > n_ctx) {
|
|
|
|
const int n_left = n_past - params.n_keep;
|
|
|
|
|
|
|
|
n_past = params.n_keep;
|
|
|
|
|
|
|
|
// insert n_left/2 tokens at the start of embd from last_n_tokens
|
|
|
|
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
|
|
|
|
|
2023-04-28 11:59:37 -04:00
|
|
|
// REVIEW - stop saving session if we run out of context
|
|
|
|
path_session = "";
|
|
|
|
|
2023-03-25 21:36:22 +02:00
|
|
|
//printf("\n---\n");
|
|
|
|
//printf("resetting: '");
|
|
|
|
//for (int i = 0; i < (int) embd.size(); i++) {
|
|
|
|
// printf("%s", llama_token_to_str(ctx, embd[i]));
|
|
|
|
//}
|
|
|
|
//printf("'\n");
|
|
|
|
//printf("\n---\n");
|
|
|
|
}
|
|
|
|
|
2023-04-28 11:59:37 -04:00
|
|
|
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
|
|
|
|
// REVIEW
|
|
|
|
if (n_session_consumed < (int) session_tokens.size()) {
|
|
|
|
size_t i = 0;
|
|
|
|
for ( ; i < embd.size(); i++) {
|
|
|
|
if (embd[i] != session_tokens[n_session_consumed]) {
|
|
|
|
session_tokens.resize(n_session_consumed);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
|
|
|
|
n_past++;
|
|
|
|
n_session_consumed++;
|
|
|
|
|
|
|
|
if (n_session_consumed >= (int) session_tokens.size()) {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (i > 0) {
|
|
|
|
embd.erase(embd.begin(), embd.begin() + i);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-04-21 11:18:09 -07:00
|
|
|
// evaluate tokens in batches
|
|
|
|
// embd is typically prepared beforehand to fit within a batch, but not always
|
|
|
|
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
|
|
|
|
int n_eval = (int) embd.size() - i;
|
|
|
|
if (n_eval > params.n_batch) {
|
|
|
|
n_eval = params.n_batch;
|
|
|
|
}
|
|
|
|
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
|
|
|
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
n_past += n_eval;
|
2023-03-10 20:40:58 +02:00
|
|
|
}
|
2023-04-28 11:59:37 -04:00
|
|
|
|
|
|
|
if (embd.size() > 0 && !path_session.empty()) {
|
|
|
|
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
|
|
|
n_session_consumed = session_tokens.size();
|
|
|
|
}
|
2023-03-10 20:40:58 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
embd.clear();
|
|
|
|
|
2023-03-25 21:36:22 +02:00
|
|
|
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
2023-03-12 23:07:34 +01:00
|
|
|
// out of user input, sample next token
|
2023-04-29 09:51:06 +03:00
|
|
|
const float temp = params.temp;
|
|
|
|
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
|
|
|
|
const float top_p = params.top_p;
|
|
|
|
const float tfs_z = params.tfs_z;
|
|
|
|
const float typical_p = params.typical_p;
|
|
|
|
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
|
|
|
|
const float repeat_penalty = params.repeat_penalty;
|
|
|
|
const float alpha_presence = params.presence_penalty;
|
llama : new sampling algorithms (#1126)
* Sample interface, new samplers.
New samplers:
- locally typical sampling
- tail free sampling
- frequency and presence penalty
- mirostat
Ignore EOS fix: -inf should be used.
* mirostat
* Added --logit-bias and --no-penalize-nl, removed std::span
* Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
* Save and load example adjust
* Tests
* Windows build fix
* Windows test fix
2023-04-29 08:34:41 +03:00
|
|
|
const float alpha_frequency = params.frequency_penalty;
|
2023-04-29 09:51:06 +03:00
|
|
|
const int mirostat = params.mirostat;
|
|
|
|
const float mirostat_tau = params.mirostat_tau;
|
|
|
|
const float mirostat_eta = params.mirostat_eta;
|
|
|
|
const bool penalize_nl = params.penalize_nl;
|
2023-03-10 20:40:58 +02:00
|
|
|
|
2023-04-28 11:59:37 -04:00
|
|
|
// optionally save the session on first sample (for faster prompt loading next time)
|
|
|
|
if (!path_session.empty() && need_to_save_session) {
|
|
|
|
need_to_save_session = false;
|
|
|
|
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
|
|
|
}
|
|
|
|
|
2023-03-22 07:32:36 +02:00
|
|
|
llama_token id = 0;
|
2023-03-10 20:40:58 +02:00
|
|
|
|
|
|
|
{
|
2023-03-22 07:32:36 +02:00
|
|
|
auto logits = llama_get_logits(ctx);
|
llama : new sampling algorithms (#1126)
* Sample interface, new samplers.
New samplers:
- locally typical sampling
- tail free sampling
- frequency and presence penalty
- mirostat
Ignore EOS fix: -inf should be used.
* mirostat
* Added --logit-bias and --no-penalize-nl, removed std::span
* Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
* Save and load example adjust
* Tests
* Windows build fix
* Windows test fix
2023-04-29 08:34:41 +03:00
|
|
|
auto n_vocab = llama_n_vocab(ctx);
|
|
|
|
|
|
|
|
// Apply params.logit_bias map
|
|
|
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
|
|
|
logits[it->first] += it->second;
|
|
|
|
}
|
2023-03-10 20:40:58 +02:00
|
|
|
|
llama : new sampling algorithms (#1126)
* Sample interface, new samplers.
New samplers:
- locally typical sampling
- tail free sampling
- frequency and presence penalty
- mirostat
Ignore EOS fix: -inf should be used.
* mirostat
* Added --logit-bias and --no-penalize-nl, removed std::span
* Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
* Save and load example adjust
* Tests
* Windows build fix
* Windows test fix
2023-04-29 08:34:41 +03:00
|
|
|
std::vector<llama_token_data> candidates;
|
|
|
|
candidates.reserve(n_vocab);
|
|
|
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
|
|
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
2023-03-19 19:22:48 +01:00
|
|
|
}
|
|
|
|
|
llama : new sampling algorithms (#1126)
* Sample interface, new samplers.
New samplers:
- locally typical sampling
- tail free sampling
- frequency and presence penalty
- mirostat
Ignore EOS fix: -inf should be used.
* mirostat
* Added --logit-bias and --no-penalize-nl, removed std::span
* Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
* Save and load example adjust
* Tests
* Windows build fix
* Windows test fix
2023-04-29 08:34:41 +03:00
|
|
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
|
|
|
|
|
|
|
// Apply penalties
|
|
|
|
float nl_logit = logits[llama_token_nl()];
|
|
|
|
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
|
|
|
|
llama_sample_repetition_penalty(ctx, &candidates_p,
|
|
|
|
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
|
|
|
last_n_repeat, repeat_penalty);
|
|
|
|
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
|
|
|
|
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
|
|
|
last_n_repeat, alpha_frequency, alpha_presence);
|
|
|
|
if (!penalize_nl) {
|
|
|
|
logits[llama_token_nl()] = nl_logit;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (temp <= 0) {
|
|
|
|
// Greedy sampling
|
|
|
|
id = llama_sample_token_greedy(ctx, &candidates_p);
|
|
|
|
} else {
|
|
|
|
if (mirostat == 1) {
|
|
|
|
static float mirostat_mu = 2.0f * mirostat_tau;
|
|
|
|
const int mirostat_m = 100;
|
|
|
|
llama_sample_temperature(ctx, &candidates_p, temp);
|
|
|
|
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
|
|
|
} else if (mirostat == 2) {
|
|
|
|
static float mirostat_mu = 2.0f * mirostat_tau;
|
|
|
|
llama_sample_temperature(ctx, &candidates_p, temp);
|
|
|
|
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
|
|
|
} else {
|
|
|
|
// Temperature sampling
|
|
|
|
llama_sample_top_k(ctx, &candidates_p, top_k);
|
|
|
|
llama_sample_tail_free(ctx, &candidates_p, tfs_z);
|
|
|
|
llama_sample_typical(ctx, &candidates_p, typical_p);
|
|
|
|
llama_sample_top_p(ctx, &candidates_p, top_p);
|
|
|
|
llama_sample_temperature(ctx, &candidates_p, temp);
|
|
|
|
id = llama_sample_token(ctx, &candidates_p);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// printf("`%d`", candidates_p.size);
|
2023-03-12 05:27:42 -04:00
|
|
|
|
|
|
|
last_n_tokens.erase(last_n_tokens.begin());
|
|
|
|
last_n_tokens.push_back(id);
|
2023-03-10 20:40:58 +02:00
|
|
|
}
|
|
|
|
|
2023-03-23 15:22:47 -05:00
|
|
|
// replace end of text token with newline token when in interactive mode
|
2023-04-14 21:58:43 +02:00
|
|
|
if (id == llama_token_eos() && params.interactive && !params.instruct) {
|
2023-03-23 15:22:47 -05:00
|
|
|
id = llama_token_newline.front();
|
|
|
|
if (params.antiprompt.size() != 0) {
|
|
|
|
// tokenize and inject first reverse prompt
|
|
|
|
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
|
|
|
|
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-10 20:40:58 +02:00
|
|
|
// add it to the context
|
|
|
|
embd.push_back(id);
|
2023-03-12 22:13:28 +01:00
|
|
|
|
|
|
|
// echo this to console
|
|
|
|
input_noecho = false;
|
|
|
|
|
|
|
|
// decrement remaining sampling budget
|
2023-03-25 21:36:22 +02:00
|
|
|
--n_remain;
|
2023-03-10 20:40:58 +02:00
|
|
|
} else {
|
2023-03-12 23:07:34 +01:00
|
|
|
// some user input remains from prompt or interaction, forward it to processing
|
2023-03-25 21:36:22 +02:00
|
|
|
while ((int) embd_inp.size() > n_consumed) {
|
|
|
|
embd.push_back(embd_inp[n_consumed]);
|
2023-03-12 05:27:42 -04:00
|
|
|
last_n_tokens.erase(last_n_tokens.begin());
|
2023-03-25 21:36:22 +02:00
|
|
|
last_n_tokens.push_back(embd_inp[n_consumed]);
|
|
|
|
++n_consumed;
|
2023-03-19 19:46:32 +02:00
|
|
|
if ((int) embd.size() >= params.n_batch) {
|
2023-03-10 20:40:58 +02:00
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// display text
|
2023-03-12 22:13:28 +01:00
|
|
|
if (!input_noecho) {
|
|
|
|
for (auto id : embd) {
|
2023-03-22 07:32:36 +02:00
|
|
|
printf("%s", llama_token_to_str(ctx, id));
|
2023-03-12 22:13:28 +01:00
|
|
|
}
|
|
|
|
fflush(stdout);
|
|
|
|
}
|
2023-03-19 13:44:30 -06:00
|
|
|
// reset color to default if there is no pending user input
|
2023-03-25 21:36:22 +02:00
|
|
|
if (!input_noecho && (int)embd_inp.size() == n_consumed) {
|
2023-03-28 17:09:55 +03:00
|
|
|
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
|
2023-03-19 13:44:30 -06:00
|
|
|
}
|
2023-03-12 22:13:28 +01:00
|
|
|
|
|
|
|
// in interactive mode, and not currently processing queued inputs;
|
|
|
|
// check if we should prompt the user for more
|
2023-03-25 21:36:22 +02:00
|
|
|
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
|
2023-03-28 17:09:55 +03:00
|
|
|
|
2023-04-14 21:58:43 +02:00
|
|
|
// check for reverse prompt
|
|
|
|
if (params.antiprompt.size()) {
|
2023-03-28 17:09:55 +03:00
|
|
|
std::string last_output;
|
|
|
|
for (auto id : last_n_tokens) {
|
|
|
|
last_output += llama_token_to_str(ctx, id);
|
|
|
|
}
|
2023-03-21 17:04:43 +01:00
|
|
|
|
2023-04-14 21:58:43 +02:00
|
|
|
is_antiprompt = false;
|
2023-03-28 17:09:55 +03:00
|
|
|
// Check if each of the reverse prompts appears at the end of the output.
|
2023-04-14 21:58:43 +02:00
|
|
|
for (std::string & antiprompt : params.antiprompt) {
|
|
|
|
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
|
2023-03-28 17:09:55 +03:00
|
|
|
is_interacting = true;
|
2023-04-14 21:58:43 +02:00
|
|
|
is_antiprompt = true;
|
2023-03-28 17:09:55 +03:00
|
|
|
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
|
|
|
|
fflush(stdout);
|
|
|
|
break;
|
|
|
|
}
|
2023-03-19 20:33:06 +01:00
|
|
|
}
|
2023-03-12 22:13:28 +01:00
|
|
|
}
|
2023-03-24 23:17:58 +02:00
|
|
|
|
2023-04-14 21:58:43 +02:00
|
|
|
if (n_past > 0 && is_interacting) {
|
2023-03-21 18:11:01 +01:00
|
|
|
// potentially set color to indicate we are taking user input
|
2023-03-28 17:09:55 +03:00
|
|
|
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
|
2023-03-21 18:11:01 +01:00
|
|
|
|
2023-04-03 18:00:55 +02:00
|
|
|
#if defined (_WIN32)
|
|
|
|
// Windows: must reactivate sigint handler after each signal
|
|
|
|
signal(SIGINT, sigint_handler);
|
|
|
|
#endif
|
|
|
|
|
2023-04-14 21:58:43 +02:00
|
|
|
if (params.instruct) {
|
2023-03-19 18:37:02 +02:00
|
|
|
printf("\n> ");
|
|
|
|
}
|
|
|
|
|
2023-04-14 21:58:43 +02:00
|
|
|
std::string buffer;
|
2023-03-25 14:03:19 +02:00
|
|
|
if (!params.input_prefix.empty()) {
|
|
|
|
buffer += params.input_prefix;
|
2023-03-25 16:22:05 +02:00
|
|
|
printf("%s", buffer.c_str());
|
2023-03-25 14:03:19 +02:00
|
|
|
}
|
|
|
|
|
2023-04-14 21:58:43 +02:00
|
|
|
std::string line;
|
|
|
|
bool another_line = true;
|
|
|
|
do {
|
|
|
|
#if defined(_WIN32)
|
|
|
|
std::wstring wline;
|
|
|
|
if (!std::getline(std::wcin, wline)) {
|
|
|
|
// input stream is bad or EOF received
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
win32_utf8_encode(wline, line);
|
|
|
|
#else
|
|
|
|
if (!std::getline(std::cin, line)) {
|
|
|
|
// input stream is bad or EOF received
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
if (line.empty() || line.back() != '\\') {
|
|
|
|
another_line = false;
|
|
|
|
} else {
|
|
|
|
line.pop_back(); // Remove the continue character
|
|
|
|
}
|
|
|
|
buffer += line + '\n'; // Append the line to the result
|
|
|
|
} while (another_line);
|
2023-03-21 18:11:01 +01:00
|
|
|
|
|
|
|
// done taking input, reset color
|
2023-03-28 17:09:55 +03:00
|
|
|
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
|
2023-03-19 18:37:02 +02:00
|
|
|
|
2023-03-28 17:09:55 +03:00
|
|
|
// Add tokens to embd only if the input buffer is non-empty
|
|
|
|
// Entering an empty line lets the user pass control back
|
|
|
|
if (buffer.length() > 1) {
|
2023-03-13 00:35:51 +01:00
|
|
|
|
2023-04-14 21:58:43 +02:00
|
|
|
// instruct mode: insert instruction prefix
|
|
|
|
if (params.instruct && !is_antiprompt) {
|
2023-03-28 17:09:55 +03:00
|
|
|
n_consumed = embd_inp.size();
|
|
|
|
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
|
|
|
|
}
|
2023-03-12 22:13:28 +01:00
|
|
|
|
2023-03-28 17:09:55 +03:00
|
|
|
auto line_inp = ::llama_tokenize(ctx, buffer, false);
|
|
|
|
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
|
|
|
|
2023-04-14 21:58:43 +02:00
|
|
|
// instruct mode: insert response suffix
|
|
|
|
if (params.instruct) {
|
2023-03-28 17:09:55 +03:00
|
|
|
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
|
|
|
|
}
|
|
|
|
|
|
|
|
n_remain -= line_inp.size();
|
|
|
|
}
|
2023-03-19 13:44:30 -06:00
|
|
|
|
|
|
|
input_noecho = true; // do not echo this again
|
2023-03-12 22:13:28 +01:00
|
|
|
}
|
2023-03-24 23:17:58 +02:00
|
|
|
|
|
|
|
if (n_past > 0) {
|
|
|
|
is_interacting = false;
|
|
|
|
}
|
2023-03-10 20:40:58 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
// end of text token
|
2023-04-06 17:59:11 +02:00
|
|
|
if (!embd.empty() && embd.back() == llama_token_eos()) {
|
2023-04-14 21:58:43 +02:00
|
|
|
if (params.instruct) {
|
2023-03-24 10:22:39 -05:00
|
|
|
is_interacting = true;
|
|
|
|
} else {
|
|
|
|
fprintf(stderr, " [end of text]\n");
|
|
|
|
break;
|
|
|
|
}
|
2023-03-10 20:40:58 +02:00
|
|
|
}
|
2023-03-19 19:31:17 +01:00
|
|
|
|
|
|
|
// In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
|
2023-03-26 16:06:10 +03:00
|
|
|
if (params.interactive && n_remain <= 0 && params.n_predict != -1) {
|
2023-03-25 21:36:22 +02:00
|
|
|
n_remain = params.n_predict;
|
2023-03-19 19:31:17 +01:00
|
|
|
is_interacting = true;
|
|
|
|
}
|
2023-03-10 20:40:58 +02:00
|
|
|
}
|
|
|
|
|
2023-03-15 13:56:24 -06:00
|
|
|
#if defined (_WIN32)
|
|
|
|
signal(SIGINT, SIG_DFL);
|
|
|
|
#endif
|
2023-03-12 22:13:28 +01:00
|
|
|
|
2023-03-22 07:32:36 +02:00
|
|
|
llama_print_timings(ctx);
|
|
|
|
llama_free(ctx);
|
2023-03-10 20:40:58 +02:00
|
|
|
|
2023-03-28 17:09:55 +03:00
|
|
|
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
|
2023-03-15 15:39:38 -04:00
|
|
|
|
2023-03-10 20:40:58 +02:00
|
|
|
return 0;
|
|
|
|
}
|