#include "llama.h" #include <cstdio> #include <cstring> #include <iostream> #include <string> #include <vector> static void print_usage(int, char ** argv) { printf("\nexample usage:\n"); printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]); printf("\n"); } int main(int argc, char ** argv) { std::string model_path; int ngl = 99; int n_ctx = 2048; // parse command line arguments for (int i = 1; i < argc; i++) { try { if (strcmp(argv[i], "-m") == 0) { if (i + 1 < argc) { model_path = argv[++i]; } else { print_usage(argc, argv); return 1; } } else if (strcmp(argv[i], "-c") == 0) { if (i + 1 < argc) { n_ctx = std::stoi(argv[++i]); } else { print_usage(argc, argv); return 1; } } else if (strcmp(argv[i], "-ngl") == 0) { if (i + 1 < argc) { ngl = std::stoi(argv[++i]); } else { print_usage(argc, argv); return 1; } } else { print_usage(argc, argv); return 1; } } catch (std::exception & e) { fprintf(stderr, "error: %s\n", e.what()); print_usage(argc, argv); return 1; } } if (model_path.empty()) { print_usage(argc, argv); return 1; } // only print errors llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) { if (level >= GGML_LOG_LEVEL_ERROR) { fprintf(stderr, "%s", text); } }, nullptr); // initialize the model llama_model_params model_params = llama_model_default_params(); model_params.n_gpu_layers = ngl; llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params); if (!model) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); return 1; } // initialize the context llama_context_params ctx_params = llama_context_default_params(); ctx_params.n_ctx = n_ctx; ctx_params.n_batch = n_ctx; llama_context * ctx = llama_new_context_with_model(model, ctx_params); if (!ctx) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); return 1; } // initialize the sampler llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params()); llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1)); llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f)); llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // helper function to evaluate a prompt and generate a response auto generate = [&](const std::string & prompt) { std::string response; // tokenize the prompt const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true); std::vector<llama_token> prompt_tokens(n_prompt_tokens); if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) { GGML_ABORT("failed to tokenize the prompt\n"); } // prepare a batch for the prompt llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); llama_token new_token_id; while (true) { // check if we have enough space in the context to evaluate this batch int n_ctx = llama_n_ctx(ctx); int n_ctx_used = llama_get_kv_cache_used_cells(ctx); if (n_ctx_used + batch.n_tokens > n_ctx) { printf("\033[0m\n"); fprintf(stderr, "context size exceeded\n"); exit(0); } if (llama_decode(ctx, batch)) { GGML_ABORT("failed to decode\n"); } // sample the next token new_token_id = llama_sampler_sample(smpl, ctx, -1); // is it an end of generation? if (llama_token_is_eog(model, new_token_id)) { break; } // convert the token to a string, print it and add it to the response char buf[256]; int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true); if (n < 0) { GGML_ABORT("failed to convert token to piece\n"); } std::string piece(buf, n); printf("%s", piece.c_str()); fflush(stdout); response += piece; // prepare the next batch with the sampled token batch = llama_batch_get_one(&new_token_id, 1); } return response; }; std::vector<llama_chat_message> messages; std::vector<char> formatted(llama_n_ctx(ctx)); int prev_len = 0; while (true) { // get user input printf("\033[32m> \033[0m"); std::string user; std::getline(std::cin, user); if (user.empty()) { break; } // add the user input to the message list and format it messages.push_back({"user", strdup(user.c_str())}); int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); if (new_len > (int)formatted.size()) { formatted.resize(new_len); new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); } if (new_len < 0) { fprintf(stderr, "failed to apply the chat template\n"); return 1; } // remove previous messages to obtain the prompt to generate the response std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len); // generate a response printf("\033[33m"); std::string response = generate(prompt); printf("\n\033[0m"); // add the response to the messages messages.push_back({"assistant", strdup(response.c_str())}); prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0); if (prev_len < 0) { fprintf(stderr, "failed to apply the chat template\n"); return 1; } } // free resources for (auto & msg : messages) { free(const_cast<char *>(msg.content)); } llama_sampler_free(smpl); llama_free(ctx); llama_free_model(model); return 0; }