mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
ecb217db4f
* mtl : export the LLaMA computation graph
* ci : disable temporary
* mtl : adapt the MNIST example as starter
* mtl : no need for mtl-export tool, add cli arg for main instead
* mtl : export just a small part of the graph for now to make it easier
* mtl : move MSL code into separate file for easy editing
* mtl : initial get_rows_q4_0 kernel
* mtl : confirmed get_rows_q4_0 is working correctly
* mtl : add rms_norm kernel + confirm working
* mtl : add mul kernel + confirm working
* mtl : initial mul_mat Q4 kernel (wrong results)
* mtl : mul_mat fixes (still wrong)
* mtl : another mul_mat Q4 (still does not work)
* mtl : working mul_mat q4
* ggml : fix handling of "view" ops in ggml_graph_import()
* mtl : add rope kernel
* mtl : add reshape and transpose handling
* ggml : store offset as opt arg for ggml_view_xd() operators
* mtl : add cpy kernel + handle view ops
* mtl : confirm f16 x f32 attention mul mat
* mtl : add scale kernel
* mtl : add diag_mask_inf kernel
* mtl : fix soft_max kernel
* ggml : update ggml_nbytes() to handle non-contiguous tensors
* mtl : verify V tensor contents
* mtl : add f32 -> f32 cpy kernel
* mtl : add silu kernel
* mtl : add non-broadcast mul kernel
* mtl : full GPU inference of the computation graph
* mtl : optimize rms_norm and soft_max kernels
* mtl : add f16 mat x f32 vec multiplication kernel
* mtl : fix bug in f16 x f32 mul mat + speed-up computation
* mtl : faster mul_mat_q4_0_f32 kernel
* mtl : fix kernel signature + roll inner loop
* mtl : more threads for rms_norm + better timing
* mtl : remove printfs from inner loop
* mtl : simplify implementation
* mtl : add save/load vocab to ggml file
* mtl : plug Metal inference into llama.cpp (very quick-n-dirty)
* mtl : make it work with main example
Lots of hacks but at least now it generates text
* mtl : preparing for merge
* mtl : clean-up ggml mtl interface + support scratch / inplace
* mtl : remove temp / debug code
* metal : final refactoring and simplification
* Revert "ci : disable temporary"
This reverts commit 98c267fc77
.
* metal : add comments
* metal : clean-up stuff, fix typos
* readme : add Metal instructions
* readme : add example for main
103 lines
2.7 KiB
C++
103 lines
2.7 KiB
C++
// Evaluate a statically exported ggml computation graph with Metal
|
|
//
|
|
// - First, export a LLaMA graph:
|
|
//
|
|
// $ ./bin/main -m ../models/7B/ggml-model-q4_0.bin --export
|
|
//
|
|
// - Run this tool to evaluate the exported graph:
|
|
//
|
|
// $ ./bin/metal llama.ggml
|
|
//
|
|
// This tool is mostly intended for debugging and demonstration purposes.
|
|
// The main limitation of exporting computation graphs is that their sizes are static, which can
|
|
// often be a problem for real-world applications.
|
|
//
|
|
|
|
#include "ggml.h"
|
|
#include "ggml-metal.h"
|
|
|
|
#include <cstdio>
|
|
#include <cstring>
|
|
#include <cstdlib>
|
|
|
|
int main(int argc, char ** argv) {
|
|
ggml_time_init();
|
|
|
|
if (argc != 2) {
|
|
fprintf(stderr, "Usage: %s llama.ggml\n", argv[0]);
|
|
return -1;
|
|
}
|
|
|
|
const char * fname_cgraph = argv[1];
|
|
|
|
// load the compute graph
|
|
struct ggml_context * ctx_data = NULL;
|
|
struct ggml_context * ctx_eval = NULL;
|
|
|
|
struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
|
|
gf.n_threads = 1;
|
|
|
|
// this allocates all Metal resources and memory buffers
|
|
auto * ctx_metal = ggml_metal_init();
|
|
|
|
ggml_metal_add_buffer(ctx_metal, "data", ggml_get_mem_buffer(ctx_data), ggml_get_mem_size(ctx_data));
|
|
ggml_metal_add_buffer(ctx_metal, "eval", ggml_get_mem_buffer(ctx_eval), ggml_get_mem_size(ctx_eval));
|
|
|
|
// main
|
|
{
|
|
struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "embd");
|
|
*(int32_t *) input->data = 1; // BOS
|
|
|
|
ggml_metal_set_tensor(ctx_metal, input);
|
|
|
|
// warmup
|
|
ggml_metal_graph_compute(ctx_metal, &gf);
|
|
|
|
const int n_iter = 16;
|
|
|
|
const int64_t t0 = ggml_time_us();
|
|
|
|
// the actual inference happens here
|
|
for (int i = 0; i < n_iter; ++i) {
|
|
ggml_metal_graph_compute(ctx_metal, &gf);
|
|
}
|
|
|
|
const int64_t t1 = ggml_time_us();
|
|
|
|
printf("time: %.2f ms, %.2f ms/tok\n", (t1 - t0) / 1000.0, (t1 - t0) / 1000.0 / n_iter);
|
|
}
|
|
|
|
// debug output
|
|
{
|
|
struct ggml_tensor * logits = gf.nodes[gf.n_nodes - 1];
|
|
ggml_metal_get_tensor(ctx_metal, logits);
|
|
|
|
float * ptr = (float *) ggml_get_data(logits);
|
|
|
|
printf("logits: ");
|
|
for (int i = 0; i < 10; i++) {
|
|
printf("%8.4f ", ptr[i]);
|
|
}
|
|
printf("\n");
|
|
int imax = 0;
|
|
double sum = 0.0;
|
|
double vmax = -1e9;
|
|
for (int i = 0; i < 32000; i++) {
|
|
sum += (double) ptr[i];
|
|
if (ptr[i] > vmax) {
|
|
vmax = ptr[i];
|
|
imax = i;
|
|
}
|
|
}
|
|
printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
|
|
}
|
|
|
|
ggml_metal_free(ctx_metal);
|
|
|
|
ggml_free(ctx_data);
|
|
ggml_free(ctx_eval);
|
|
|
|
return 0;
|
|
}
|
|
|