mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-10-30 06:30:15 +01:00
ecb217db4f
* mtl : export the LLaMA computation graph
* ci : disable temporary
* mtl : adapt the MNIST example as starter
* mtl : no need for mtl-export tool, add cli arg for main instead
* mtl : export just a small part of the graph for now to make it easier
* mtl : move MSL code into separate file for easy editing
* mtl : initial get_rows_q4_0 kernel
* mtl : confirmed get_rows_q4_0 is working correctly
* mtl : add rms_norm kernel + confirm working
* mtl : add mul kernel + confirm working
* mtl : initial mul_mat Q4 kernel (wrong results)
* mtl : mul_mat fixes (still wrong)
* mtl : another mul_mat Q4 (still does not work)
* mtl : working mul_mat q4
* ggml : fix handling of "view" ops in ggml_graph_import()
* mtl : add rope kernel
* mtl : add reshape and transpose handling
* ggml : store offset as opt arg for ggml_view_xd() operators
* mtl : add cpy kernel + handle view ops
* mtl : confirm f16 x f32 attention mul mat
* mtl : add scale kernel
* mtl : add diag_mask_inf kernel
* mtl : fix soft_max kernel
* ggml : update ggml_nbytes() to handle non-contiguous tensors
* mtl : verify V tensor contents
* mtl : add f32 -> f32 cpy kernel
* mtl : add silu kernel
* mtl : add non-broadcast mul kernel
* mtl : full GPU inference of the computation graph
* mtl : optimize rms_norm and soft_max kernels
* mtl : add f16 mat x f32 vec multiplication kernel
* mtl : fix bug in f16 x f32 mul mat + speed-up computation
* mtl : faster mul_mat_q4_0_f32 kernel
* mtl : fix kernel signature + roll inner loop
* mtl : more threads for rms_norm + better timing
* mtl : remove printfs from inner loop
* mtl : simplify implementation
* mtl : add save/load vocab to ggml file
* mtl : plug Metal inference into llama.cpp (very quick-n-dirty)
* mtl : make it work with main example
Lots of hacks but at least now it generates text
* mtl : preparing for merge
* mtl : clean-up ggml mtl interface + support scratch / inplace
* mtl : remove temp / debug code
* metal : final refactoring and simplification
* Revert "ci : disable temporary"
This reverts commit 98c267fc77.
* metal : add comments
* metal : clean-up stuff, fix typos
* readme : add Metal instructions
* readme : add example for main
490 lines
15 KiB
Metal
490 lines
15 KiB
Metal
#include <metal_stdlib>
|
|
|
|
using namespace metal;
|
|
|
|
// Larger of two values. Arguments may be evaluated twice, so do not pass
// expressions with side effects.
#define MAX(x, y) ((x) > (y) ? (x) : (y))

// Number of weights in one Q4_0 quantization block.
#define QK4_0 32
// Presumably the number of quants packed per byte (two 4-bit nibbles);
// not referenced by the kernels in this file — TODO confirm usage elsewhere.
#define QR4_0 2

// One Q4_0 block: a half-precision scale followed by QK4_0 4-bit quants packed
// two per byte. dequantize_row_q4_0 reads the low nibble of byte j as element j
// and the high nibble as element j + QK4_0/2, both biased by +8.
typedef struct {
    half    d;              // delta (per-block scale)
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;

// The kernels reinterpret raw device buffers as arrays of block_q4_0, so the
// struct must be exactly sizeof(half) + QK4_0/2 bytes with no padding.
// Presumably this must also match the host-side definition — keep them in sync.
static_assert(sizeof(block_q4_0) == sizeof(half) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
|
|
|
// Expand k Q4_0-quantized values, starting at block x, into k floats at y.
// k must be a multiple of QK4_0. Within each block, quant byte j stores element j
// in its low nibble and element j + QK4_0/2 in its high nibble. Each quant is
// stored with a +8 bias and is multiplied by the block's delta d.
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
    assert(k % QK4_0 == 0);

    const int n_blocks = k / QK4_0;
    const int half_qk  = QK4_0 / 2;

    for (int ib = 0; ib < n_blocks; ++ib) {
        device const block_q4_0 & blk = x[ib];
        device float * out = y + ib*QK4_0;
        const half scale = blk.d;

        for (int j = 0; j < half_qk; ++j) {
            const uint8_t packed = blk.qs[j];
            const int lo = (packed & 0x0F) - 8;
            const int hi = (packed >> 4) - 8;

            out[j]           = lo*scale;
            out[j + half_qk] = hi*scale;
        }
    }
}
|
|
|
|
// Element-wise sum: dst[i] = src0[i] + src1[i], one thread per element.
// There is no bounds check, so the dispatch presumably launches exactly one
// thread per element — TODO confirm against the host-side dispatch.
kernel void kernel_add(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint idx = tpig;
    const float a = src0[idx];
    const float b = src1[idx];
    dst[idx] = a + b;
}
|
|
|
|
// Element-wise product: dst[i] = src0[i] * src1[i], one thread per element.
// There is no bounds check, so the dispatch presumably launches exactly one
// thread per element — TODO confirm against the host-side dispatch.
kernel void kernel_mul(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint idx = tpig;
    const float a = src0[idx];
    const float b = src1[idx];
    dst[idx] = a * b;
}
|
|
|
|
// Broadcast multiply: every ne00-wide row of src0 is multiplied element-wise
// by the single row src1, i.e. dst[i] = src0[i] * src1[i % ne00].
// Assumes src1 holds ne00 floats and src0/dst are contiguous; one thread per
// element with no bounds check.
kernel void kernel_mul_row(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        uint tpig[[thread_position_in_grid]]) {
    const int64_t col = tpig % ne00;
    dst[tpig] = src0[tpig] * src1[col];
}
|
|
|
|
// Multiply every element by a scalar: dst[i] = src0[i] * scale.
// One thread per element, with no bounds check.
kernel void kernel_scale(
        device const float * src0,
        device       float * dst,
        constant     float & scale,
        uint tpig[[thread_position_in_grid]]) {
    const float v = src0[tpig];
    dst[tpig] = v * scale;
}
|
|
|
|
// SiLU activation, one thread per element: dst[i] = x / (1 + e^-x), where x = src0[i].
kernel void kernel_silu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint idx = tpig;
    const float v = src0[idx];
    const float denom = 1.0f + exp(-v);
    dst[idx] = v / denom;
}
|
|
|
|
// ReLU activation, one thread per element: dst[i] = max(0, src0[i]).
kernel void kernel_relu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint idx = tpig;
    const float v = src0[idx];
    dst[idx] = max(0.0f, v);
}
|
|
|
|
// Row-wise softmax over contiguous rows of ne00 floats:
//   dst_row[i] = exp(src_row[i] - m) / sum_j exp(src_row[j] - m),  m = max(src_row)
//
// Launch layout: one threadgroup per row. The threadgroup grid position
// (tgpig[0], tgpig[1], tgpig[2]) is the (i01, i02, i03) row index, and the
// threads of a group stride over the row's columns.
// buf: threadgroup scratch holding at least ntg[0] floats. The tree reductions
// below assume ntg[0] is a power of two.
kernel void kernel_soft_max(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        threadgroup float  * buf [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    const int64_t row_offset = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    device const float * psrc0 = src0 + row_offset;
    device       float * pdst  = dst  + row_offset;

    const uint tid = tpitg[0];
    const uint nth = ntg[0];

    // parallel max: each thread scans a strided slice of the row
    float lmax = -INFINITY;
    for (int64_t i00 = tid; i00 < ne00; i00 += nth) {
        lmax = MAX(lmax, psrc0[i00]);
    }
    buf[tid] = lmax;

    // tree reduction; buf[0] ends up holding the row maximum
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = nth/2; i > 0; i /= 2) {
        if (tid < i) {
            buf[tid] = MAX(buf[tid], buf[tid + i]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const float row_max = buf[0];

    // Every thread must finish reading buf[0] before buf is reused for the sum.
    // Without this barrier, thread 0 could overwrite buf[0] with its partial
    // sum while threads in other SIMD groups are still reading the maximum.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // parallel sum of exp(x - max)
    float lsum = 0.0f;
    for (int64_t i00 = tid; i00 < ne00; i00 += nth) {
        lsum += exp(psrc0[i00] - row_max);
    }
    buf[tid] = lsum;

    // tree reduction; buf[0] ends up holding the row sum
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = nth/2; i > 0; i /= 2) {
        if (tid < i) {
            buf[tid] += buf[tid + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const float row_sum = buf[0];

    for (int64_t i00 = tid; i00 < ne00; i00 += nth) {
        pdst[i00] = exp(psrc0[i00] - row_max) / row_sum;
    }
}
|
|
|
|
// Causal mask: copy src0 to dst, but set every element whose column index
// exceeds n_past + row index to -INFINITY.
// One thread per element. tpig = (column, row, plane) indexes contiguous
// ne00 x ne01 matrices. The dispatch presumably covers (ne00, ne01, ne02) — TODO confirm.
kernel void kernel_diag_mask_inf(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant       int & n_past,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t col   = tpig[0];
    const int64_t row   = tpig[1];
    const int64_t plane = tpig[2];

    const int64_t idx = (plane*ne01 + row)*ne00 + col;

    const bool masked = col > n_past + row;
    dst[idx] = masked ? -INFINITY : src0[idx];
}
|
|
|
|
// Gather rows of a Q4_0 matrix and dequantize them to f32.
// Thread i looks up the row index src1[i] and dequantizes that row of src0
// (rows are nb01 bytes apart, ne00 values each). The result goes to row i of
// dst (rows nb1 bytes apart).
kernel void kernel_get_rows_q4_0(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int i   = tpig;
    const int row = src1[i];

    device const char * row_src = (device const char *) src0 + row*nb01;
    device       char * row_dst = (device       char *) dst  + i*nb1;

    dequantize_row_q4_0((device const block_q4_0 *) row_src, (device float *) row_dst, ne00);
}
|
|
|
|
// RMS normalization, one row per threadgroup:
//   y = x / sqrt(mean(x^2) + eps)
// src0 rows are nb01 bytes apart. dst is written densely, ne00 floats per row.
// sum: threadgroup scratch of at least ntg floats. The reduction assumes ntg
// is a power of two.
kernel void kernel_rms_norm(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant     float & eps,
        threadgroup float  * sum [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    const uint row = tgpig;
    const uint tid = tpitg;

    device const float * x = (device const float *) ((device const char *) src0 + row*nb01);

    // per-thread partial sum of squares over a strided slice of the row
    float acc = 0.0f;
    for (int i00 = tid; i00 < ne00; i00 += ntg) {
        acc += x[i00] * x[i00];
    }
    sum[tid] = acc;

    // tree reduction; sum[0] ends up holding the total
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint stride = ntg/2; stride > 0; stride /= 2) {
        if (tid < stride) {
            sum[tid] += sum[tid + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // thread 0 turns the total into the mean and publishes it
    if (tid == 0) {
        sum[0] /= ne00;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float mean  = sum[0];
    const float scale = 1.0f/sqrt(mean + eps);

    device float * y = dst + row*ne00;
    for (int i00 = tid; i00 < ne00; i00 += ntg) {
        y[i00] = x[i00] * scale;
    }
}
|
|
|
|
// Q4_0 matrix times f32 matrix, one output element per threadgroup:
//   dst[r1*ne0 + r0] = dot(row r0 of src0, row r1 of src1)
// The threadgroup grid position is (r0, r1). Within a group, tpitg.x strides
// over the row's Q4_0 blocks, and tpitg.y picks which 4-byte chunk of a block's
// 16 quant bytes the thread handles. The chunk math below therefore assumes
// tptg.y == 4 — TODO confirm against the host dispatch.
// src0 rows are assumed contiguous (ne00/QK4_0 blocks each); src1 rows are ne10 floats.
// sum: threadgroup scratch of at least tptg.x*tptg.y floats. The reduction
// assumes that count is a power of two.
kernel void kernel_mul_mat_q4_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2  tpig[[thread_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {
    const int nb = ne00/QK4_0;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
    device const float      * y = (device const float      *) src1 + r1*ne10;

    const uint nth   = tptg.x*tptg.y;
    const uint ith   = tptg.y*tpitg.x + tpitg.y;
    const uint chunk = tpitg.y;

    float partial = 0.0f;

    for (int ib = tpitg.x; ib < nb; ib += tptg.x) {
        device const block_q4_0 & blk = x[ib];

        // 4 packed bytes: the low nibbles are elements 4*chunk .. 4*chunk+3,
        // the high nibbles are elements QK4_0/2 + 4*chunk .. QK4_0/2 + 4*chunk+3
        const uchar4 q = ((device const uchar4 *) blk.qs)[chunk];
        device const float4 * yb = (device const float4 *) (y + ib*QK4_0);
        const float4 ylo = yb[chunk + 0];
        const float4 yhi = yb[chunk + 4];

        float acc = 0.0f;
        for (int j = 0; j < 4; ++j) {
            const int qlo = q[j] & 0x0F;
            const int qhi = q[j] >> 4;
            acc += (qlo - 8)*ylo[j] + (qhi - 8)*yhi[j];
        }

        partial += acc*(float) blk.d;
    }

    sum[ith] = partial;

    // accumulate the partial sums from all threads in the threadgroup
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s = nth/2; s > 0; s /= 2) {
        if (ith < s) {
            sum[ith] += sum[ith + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (ith == 0) {
        dst[r1*ne0 + r0] = sum[0];
    }
}
|
|
|
|
// f16 matrix times f32 matrix, one output element per threadgroup:
//   dst[im*ne1*ne0 + r1*ne0 + r0] = dot(row r0 of src0[im], row r1 of src1[im])
// The threadgroup grid position is (r0, r1, im). Rows and matrices are located
// through the byte strides nb01/nb02 (src0) and nb11/nb12 (src1), and elements
// within a row are assumed contiguous. tpitg.x strides over the ne00 columns.
// sum: threadgroup scratch of at least tptg.x floats. The reduction assumes
// tptg.x is a power of two.
kernel void kernel_mul_mat_f16_f32(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tpig[[thread_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3  tptg[[threads_per_threadgroup]]) {
    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int64_t im = tgpig.z;

    device const half  * x = (device const half  *) (src0 + r0*nb01 + im*nb02);
    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

    const uint tid = tpitg.x;
    const uint nth = tptg.x;

    // per-thread partial dot product over a strided slice of the row
    float partial = 0.0f;
    for (int i = tid; i < ne00; i += nth) {
        partial += (float) x[i] * (float) y[i];
    }
    sum[tid] = partial;

    // accumulate the partial sums from all threads in the threadgroup
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s = nth/2; s > 0; s /= 2) {
        if (tid < s) {
            sum[tid] += sum[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (tid == 0) {
        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
    }
}
|
|
|
|
// Rotary position embedding (RoPE) on f32 data.
// One thread per row, with tpig = (i1, i2, i3). Each thread rotates element
// pairs of its row by the angle p * theta_scale^k, where
// theta_scale = 10000^(-2/n_dims). p = n_past + i2, unless bit 0 of mode is
// set, in which case p = i2.
// Bit 1 of mode selects the GPT-NeoX pairing, (i0, i0 + n_dims/2) within each
// n_dims-wide chunk. Otherwise adjacent pairs (i0, i0 + 1) across the row are used.
// The NeoX path assumes n_dims > 0 and even.
kernel void kernel_rope(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        constant       int & n_past,
        constant       int & n_dims,
        constant       int & mode,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i3 = tpig[2];
    const int64_t i2 = tpig[1];
    const int64_t i1 = tpig[0];

    const bool is_neox = mode & 2;
    // float literal: MSL has no double support, so avoid an unsuffixed 10000.0
    const float theta_scale = pow(10000.0f, -2.0f/n_dims);

    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

    float theta = (float)p;

    if (!is_neox) {
        // rotate adjacent pairs (x[i0], x[i0 + 1])
        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
            const float cos_theta = cos(theta);
            const float sin_theta = sin(theta);

            theta *= theta_scale;

            device const float * const src = (device const float *)((device const char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
            device       float * dst_data  = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

            const float x0 = src[0];
            const float x1 = src[1];

            dst_data[0] = x0*cos_theta - x1*sin_theta;
            dst_data[1] = x0*sin_theta + x1*cos_theta;
        }
    } else {
        // GPT-NeoX: rotate (x[i0], x[i0 + n_dims/2]) within each n_dims-wide chunk.
        // NOTE(review): this mirrors the ggml CPU reference NeoX path; confirm
        // the two implementations stay in sync.
        for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
            for (int64_t ic = 0; ic < n_dims; ic += 2) {
                const float cos_theta = cos(theta);
                const float sin_theta = sin(theta);

                theta *= theta_scale;

                const int64_t i0 = ib*n_dims + ic/2;

                device const float * const src = (device const float *)((device const char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                device       float * dst_data  = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

                const float x0 = src[0];
                const float x1 = src[n_dims/2];

                dst_data[0]        = x0*cos_theta - x1*sin_theta;
                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
            }
        }
    }
}
|
|
|
|
// Copy an f32 tensor into an f16 tensor with the same element count, possibly
// with a different shape.
// One threadgroup per source row (threadgroup grid position = (i01, i02, i03)).
// Source elements are read through the nb0x byte strides. The row's linear
// element offset is decomposed in the destination shape (ne0..ne3) to find
// where it starts in dst. The row is then written densely from that point, so
// dst rows are assumed contiguous — TODO confirm the caller guarantees this.
kernel void kernel_cpy_f32_f16(
        device const float * src0,
        device        half * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // linear index of the first element of this source row
    const int64_t n = ((i03*ne02 + i02)*ne01 + i01)*ne00;

    // decompose n in the destination shape
    const int64_t dst_vol3 = ne2*ne1*ne0;
    const int64_t dst_vol2 = ne1*ne0;

    const int64_t i3 = n / dst_vol3;
    int64_t rem = n - i3*dst_vol3;
    const int64_t i2 = rem / dst_vol2;
    rem -= i2*dst_vol2;
    const int64_t i1 = rem / ne0;
    const int64_t i0 = rem - i1*ne0;

    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        dst_data[i00] = *((device const float *) (src_row + i00*nb00));
    }
}
|
|
|
|
// Copy an f32 tensor into another f32 tensor with the same element count,
// possibly with a different shape.
// One threadgroup per source row (threadgroup grid position = (i01, i02, i03)).
// Source elements are read through the nb0x byte strides. The row's linear
// element offset is decomposed in the destination shape (ne0..ne3) to find
// where it starts in dst. The row is then written densely from that point, so
// dst rows are assumed contiguous — TODO confirm the caller guarantees this.
kernel void kernel_cpy_f32_f32(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // linear index of the first element of this source row
    const int64_t n = ((i03*ne02 + i02)*ne01 + i01)*ne00;

    // decompose n in the destination shape
    const int64_t dst_vol3 = ne2*ne1*ne0;
    const int64_t dst_vol2 = ne1*ne0;

    const int64_t i3 = n / dst_vol3;
    int64_t rem = n - i3*dst_vol3;
    const int64_t i2 = rem / dst_vol2;
    rem -= i2*dst_vol2;
    const int64_t i1 = rem / ne0;
    const int64_t i0 = rem - i1*ne0;

    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        dst_data[i00] = *((device const float *) (src_row + i00*nb00));
    }
}
|