mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
metal : fix kernel_norm (fixes Falcon on Metal) (#3057)
* metal : fix kernel_norm ggml-ci * metal : put warning in kernel_norm to not combine the loops * metal : restore original F16 mat-vec multiplication It works after the norm fixes * common : don't do warm-up with more than n_batch tokens (close #3058) ggml-ci * metal : minor
This commit is contained in:
parent
fec2fb19e4
commit
c4f496648c
@ -773,7 +773,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|||||||
LOG("warming up the model with an empty run\n");
|
LOG("warming up the model with an empty run\n");
|
||||||
|
|
||||||
const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
|
const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
|
||||||
llama_eval(lctx, tmp.data(), tmp.size(), 0, params.n_threads);
|
llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
|
||||||
llama_reset_timings(lctx);
|
llama_reset_timings(lctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,27 +220,32 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
//// broadcast
|
// broadcast
|
||||||
//if (tpitg == 0) {
|
if (tpitg == 0) {
|
||||||
// sum[0] /= ne00;
|
sum[0] /= ne00;
|
||||||
//}
|
}
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
const float mean = sum[0];
|
const float mean = sum[0];
|
||||||
|
|
||||||
// recenter and VARIANCE
|
// recenter
|
||||||
device float * y = dst + tgpig*ne00;
|
device float * y = dst + tgpig*ne00;
|
||||||
sum[tpitg] = 0.0f;
|
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
y[i00] = x[i00] - mean;
|
y[i00] = x[i00] - mean;
|
||||||
|
}
|
||||||
|
|
||||||
|
// VARIANCE
|
||||||
|
// parallel sum
|
||||||
|
//
|
||||||
|
// WARNING: combining this loop with the one above will give you wrong results for nth == 256
|
||||||
|
// I have no idea why, so for now I am keeping them separate. But this behavior is very concerning.
|
||||||
|
// Tested with:
|
||||||
|
// ./perplexity -m ./falcon-7b/ggml-model-q4_0.gguf -f wiki.test.raw -ngl 1 -t 4
|
||||||
|
//
|
||||||
|
sum[tpitg] = 0.0f;
|
||||||
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
sum[tpitg] += y[i00] * y[i00];
|
sum[tpitg] += y[i00] * y[i00];
|
||||||
}
|
}
|
||||||
|
|
||||||
//// VARIANCE
|
|
||||||
//// parallel sum
|
|
||||||
//sum[tpitg] = 0.0f;
|
|
||||||
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
||||||
// sum[tpitg] += y[i00] * y[i00];
|
|
||||||
//}
|
|
||||||
// reduce
|
// reduce
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
for (uint i = ntg/2; i > 0; i /= 2) {
|
for (uint i = ntg/2; i > 0; i /= 2) {
|
||||||
@ -249,11 +254,11 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
//// broadcast
|
// broadcast
|
||||||
//if (tpitg == 0) {
|
if (tpitg == 0) {
|
||||||
// sum[0] /= ne00;
|
sum[0] /= ne00;
|
||||||
//}
|
}
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
const float variance = sum[0];
|
const float variance = sum[0];
|
||||||
|
|
||||||
const float scale = 1.0f/sqrt(variance + eps);
|
const float scale = 1.0f/sqrt(variance + eps);
|
||||||
@ -262,7 +267,6 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
kernel void kernel_rms_norm(
|
kernel void kernel_rms_norm(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -630,7 +634,6 @@ kernel void kernel_mul_mat_f16_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_alibi_f32(
|
kernel void kernel_alibi_f32(
|
||||||
|
Loading…
Reference in New Issue
Block a user