mirror of https://github.com/ggerganov/llama.cpp.git
mtl : export the LLaMA computation graph
commit f85020b19a
parent 7552ac5863
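In short: this commit adds a new mtl-export example program that loads a model and dumps its ggml computation graph to a file, threads an optional cgraph_fname parameter through llama_eval_internal so the graph built during evaluation can be handed to ggml_graph_export, and exposes the new llama_eval_export entry point in the public header.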
examples/CMakeLists.txt
@@ -37,6 +37,7 @@ else()
     add_subdirectory(save-load-state)
     add_subdirectory(benchmark)
     add_subdirectory(baby-llama)
+    add_subdirectory(mtl)
     if(LLAMA_BUILD_SERVER)
         add_subdirectory(server)
     endif()
examples/mtl/CMakeLists.txt (new file, 7 lines)
@@ -0,0 +1,7 @@
+set(TARGET mtl-export)
+add_executable(${TARGET} mtl-export.cpp)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()
examples/mtl/mtl-export.cpp (new file, 25 lines)
@@ -0,0 +1,25 @@
+#include "common.h"
+#include "llama.h"
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (!gpt_params_parse(argc, argv, params)) {
+        return 1;
+    }
+
+    llama_init_backend();
+
+    llama_context * ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
+    }
+
+    llama_eval_export(ctx, "llama.ggml");
+
+    llama_print_timings(ctx);
+    llama_free(ctx);
+
+    return 0;
+}
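Once built, the tool is invoked like the other examples, e.g. ./mtl-export -m models/7B/ggml-model-f16.bin (the model path here is only illustrative); it loads the model, writes the graph for a single evaluation to llama.ggml in the working directory, prints timings, and exits.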
llama.cpp (44 changed lines)
@@ -1189,17 +1189,19 @@ static bool llama_model_load(
 
 // evaluate the transformer
 //
-//   - lctx:      llama context
-//   - tokens:    new batch of tokens to process
-//   - n_past:    the context size so far
-//   - n_threads: number of threads to use
+//   - lctx:         llama context
+//   - tokens:       new batch of tokens to process
+//   - n_past:       the context size so far
+//   - n_threads:    number of threads to use
+//   - cgraph_fname: filename of the exported computation graph (TODO: TMP!!!)
 //
 static bool llama_eval_internal(
-        llama_context & lctx,
-    const llama_token * tokens,
-            const int   n_tokens,
-            const int   n_past,
-            const int   n_threads) {
+        llama_context & lctx,
+    const llama_token * tokens,
+            const int   n_tokens,
+            const int   n_past,
+            const int   n_threads,
+           const char * cgraph_fname) {
 
     // enforce that the first token is BOS
     if (n_past == 0 && tokens[0] != llama_token_bos()) {
@@ -1422,6 +1424,10 @@ static bool llama_eval_internal(
     ggml_build_forward_expand(&gf, inpL);
     ggml_graph_compute       (ctx0, &gf);
 
+    if (cgraph_fname) {
+        ggml_graph_export(&gf, cgraph_fname);
+    }
+
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
     // requires GGML_PERF to be defined
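For context, the file written by ggml_graph_export is meant to be read back through ggml on the consuming side (presumably the Metal runner the mtl directory is preparing for). Below is a minimal consumer sketch, assuming the ggml_graph_import counterpart introduced alongside ggml_graph_export in ggml; its exact signature may differ between ggml revisions, and the file name matches the mtl-export example above:

    // graph-import.cpp: a sketch, not part of this commit
    #include "ggml.h"

    #include <cstdio>

    int main(void) {
        struct ggml_context * ctx_data = NULL; // holds the graph's tensor data
        struct ggml_context * ctx_eval = NULL; // holds the graph's nodes

        // assumed counterpart of ggml_graph_export
        struct ggml_cgraph gf = ggml_graph_import("llama.ggml", &ctx_data, &ctx_eval);

        // re-run the imported graph on the CPU as a sanity check
        ggml_graph_compute(ctx_eval, &gf);

        fprintf(stderr, "recomputed %d nodes\n", gf.n_nodes);

        ggml_free(ctx_eval);
        ggml_free(ctx_data);

        return 0;
    }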
@@ -2899,7 +2905,7 @@ int llama_eval(
                          int   n_tokens,
                          int   n_past,
                          int   n_threads) {
-    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
+    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -2914,6 +2920,24 @@ int llama_eval(
     return 0;
 }
 
+int llama_eval_export(struct llama_context * ctx, const char * fname) {
+    // these values determine the maximum inference sizes of the exported computation graph
+    // TODO: TMP !!!
+    //const int n_ctx   = ctx->model.hparams.n_ctx;
+    //const int n_batch = 512;
+    const int n_ctx   = 128;
+    const int n_batch = 32;
+
+    const std::vector<llama_token> tmp(n_batch, llama_token_bos());
+
+    if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), n_ctx, 1, fname)) {
+        fprintf(stderr, "%s: failed to eval\n", __func__);
+        return 1;
+    }
+
+    return 0;
+}
+
 int llama_tokenize(
         struct llama_context * ctx,
                   const char * text,
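Note that the export evaluates a dummy batch of BOS tokens with n_past set to n_ctx, so the run exists only to build and serialize the graph; the hard-coded n_ctx = 128 and n_batch = 32 bound the inference sizes baked into the exported graph, and the commented-out lines show the intended non-temporary values.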
llama.h (4 changed lines)
@@ -173,6 +173,10 @@ extern "C" {
                         int   n_past,
                         int   n_threads);
 
+    // Export a computation graph for model inference
+    // TODO: very likely to change
+    LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
+
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
     // Returns the number of tokens on success, no more than n_max_tokens
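One small caveat: llama_eval_export reports failure through its return value, which the mtl-export example above discards. A caller that wants to propagate the error can check it, along these lines:

    // sketch of a checked call; "llama.ggml" matches the example above
    if (llama_eval_export(ctx, "llama.ggml") != 0) {
        fprintf(stderr, "%s: graph export failed\n", __func__);
        return 1;
    }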