From a6744e43e80f4be6398fc7733a01642c846dce1d Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Fri, 1 Nov 2024 23:50:59 +0100 Subject: [PATCH] llama : add simple-chat example (#10124) * llama : add simple-chat example --------- Co-authored-by: Xuan Son Nguyen --- Makefile | 6 + examples/CMakeLists.txt | 1 + examples/simple-chat/CMakeLists.txt | 5 + examples/simple-chat/README.md | 7 + examples/simple-chat/simple-chat.cpp | 197 +++++++++++++++++++++++++++ ggml/include/ggml.h | 8 +- 6 files changed, 220 insertions(+), 4 deletions(-) create mode 100644 examples/simple-chat/CMakeLists.txt create mode 100644 examples/simple-chat/README.md create mode 100644 examples/simple-chat/simple-chat.cpp diff --git a/Makefile b/Makefile index 719f45d16..051436344 100644 --- a/Makefile +++ b/Makefile @@ -34,6 +34,7 @@ BUILD_TARGETS = \ llama-save-load-state \ llama-server \ llama-simple \ + llama-simple-chat \ llama-speculative \ llama-tokenize \ llama-vdot \ @@ -1287,6 +1288,11 @@ llama-simple: examples/simple/simple.cpp \ $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) +llama-simple-chat: examples/simple-chat/simple-chat.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + llama-tokenize: examples/tokenize/tokenize.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index ead630661..6df318c19 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -49,6 +49,7 @@ else() endif() add_subdirectory(save-load-state) add_subdirectory(simple) + add_subdirectory(simple-chat) add_subdirectory(speculative) add_subdirectory(tokenize) endif() diff --git a/examples/simple-chat/CMakeLists.txt b/examples/simple-chat/CMakeLists.txt new file mode 100644 index 000000000..87723533b --- /dev/null +++ 
b/examples/simple-chat/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-simple-chat)
+add_executable(${TARGET} simple-chat.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/simple-chat/README.md b/examples/simple-chat/README.md
new file mode 100644
index 000000000..f0099ce3d
--- /dev/null
+++ b/examples/simple-chat/README.md
@@ -0,0 +1,7 @@
+# llama.cpp/example/simple-chat
+
+The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file.
+
+```bash
+./llama-simple-chat -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048
+...
diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp
new file mode 100644
index 000000000..14264cfcb
--- /dev/null
+++ b/examples/simple-chat/simple-chat.cpp
@@ -0,0 +1,197 @@
+#include "llama.h"
+#include <cstdio>
+#include <cstring>
+#include <iostream>
+#include <string>
+#include <vector>
+
+static void print_usage(int, char ** argv) {
+    printf("\nexample usage:\n");
+    printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
+    printf("\n");
+}
+
+int main(int argc, char ** argv) {
+    std::string model_path;
+    int ngl = 99;
+    int n_ctx = 2048;
+
+    // parse command line arguments
+    for (int i = 1; i < argc; i++) {
+        try {
+            if (strcmp(argv[i], "-m") == 0) {
+                if (i + 1 < argc) {
+                    model_path = argv[++i];
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else if (strcmp(argv[i], "-c") == 0) {
+                if (i + 1 < argc) {
+                    n_ctx = std::stoi(argv[++i]);
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else if (strcmp(argv[i], "-ngl") == 0) {
+                if (i + 1 < argc) {
+                    ngl = std::stoi(argv[++i]);
+                } else {
+                    print_usage(argc, argv);
+                    return 1;
+                }
+            } else {
+                print_usage(argc, argv);
+                return 1;
+            }
+        } catch (std::exception & e) {
+            fprintf(stderr, "error: %s\n", e.what());
+            print_usage(argc, argv);
+            return 1;
+        }
+    }
+    if (model_path.empty()) {
+        print_usage(argc, argv);
+        return 1;
+    }
+
+    // only print errors
+    llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
+        if (level >= GGML_LOG_LEVEL_ERROR) {
+            fprintf(stderr, "%s", text);
+        }
+    }, nullptr);
+
+    // initialize the model
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers = ngl;
+
+    llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
+    if (!model) {
+        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
+        return 1;
+    }
+
+    // initialize the context
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.n_ctx = n_ctx;
+    ctx_params.n_batch = n_ctx;
+
+    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+    if (!ctx) {
+        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+        return 1;
+    }
+
+    // initialize the sampler
+    llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
+    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
+    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
+    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+
+    // helper function to evaluate a prompt and generate a response
+    auto generate = [&](const std::string & prompt) {
+        std::string response;
+
+        // tokenize the prompt
+        const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
+        if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
+            GGML_ABORT("failed to tokenize the prompt\n");
+        }
+
+        // prepare a batch for the prompt
+        llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
+        llama_token new_token_id;
+        while (true) {
+            // check if we have enough space in the context to evaluate this batch
+            int n_ctx = llama_n_ctx(ctx);
+            int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
+            if (n_ctx_used + batch.n_tokens > n_ctx) {
+                printf("\033[0m\n");
+                fprintf(stderr, "context size exceeded\n");
+                exit(0);
+            }
+
+            if (llama_decode(ctx, batch)) {
+                GGML_ABORT("failed to decode\n");
+            }
+
+            // sample the next token
+            new_token_id = llama_sampler_sample(smpl, ctx, -1);
+
+            // is it an end of generation?
+            if (llama_token_is_eog(model, new_token_id)) {
+                break;
+            }
+
+            // convert the token to a string, print it and add it to the response
+            char buf[256];
+            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
+            if (n < 0) {
+                GGML_ABORT("failed to convert token to piece\n");
+            }
+            std::string piece(buf, n);
+            printf("%s", piece.c_str());
+            fflush(stdout);
+            response += piece;
+
+            // prepare the next batch with the sampled token
+            batch = llama_batch_get_one(&new_token_id, 1);
+        }
+
+        return response;
+    };
+
+    std::vector<llama_chat_message> messages;
+    std::vector<char> formatted(llama_n_ctx(ctx));
+    int prev_len = 0;
+    while (true) {
+        // get user input
+        printf("\033[32m> \033[0m");
+        std::string user;
+        std::getline(std::cin, user);
+
+        if (user.empty()) {
+            break;
+        }
+
+        // add the user input to the message list and format it
+        messages.push_back({"user", strdup(user.c_str())});
+        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+        if (new_len > (int)formatted.size()) {
+            formatted.resize(new_len);
+            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+        }
+        if (new_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
+
+        // remove previous messages to obtain the prompt to generate the response
+        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
+
+        // generate a response
+        printf("\033[33m");
+        std::string response = generate(prompt);
+        printf("\n\033[0m");
+
+        // add the response to the messages
+        messages.push_back({"assistant", strdup(response.c_str())});
+        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        if (prev_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
+    }
+
+    // free resources
+    for (auto & msg : messages) {
+        free(const_cast<char *>(msg.content));
+    }
+    llama_sampler_free(smpl);
+    llama_free(ctx);
+    llama_free_model(model);
+
+    return 0;
+}
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 41df85557..2d93f31fa 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -558,10 +558,10 @@ extern "C" {
 
     enum ggml_log_level {
         GGML_LOG_LEVEL_NONE = 0,
-        GGML_LOG_LEVEL_INFO = 1,
-        GGML_LOG_LEVEL_WARN = 2,
-        GGML_LOG_LEVEL_ERROR = 3,
-        GGML_LOG_LEVEL_DEBUG = 4,
+        GGML_LOG_LEVEL_DEBUG = 1,
+        GGML_LOG_LEVEL_INFO = 2,
+        GGML_LOG_LEVEL_WARN = 3,
+        GGML_LOG_LEVEL_ERROR = 4,
         GGML_LOG_LEVEL_CONT = 5, // continue previous log
     };
 