mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
metal : simplify soft_max encoding
ggml-ci
This commit is contained in:
parent
390a445906
commit
580fe2064c
@ -1040,12 +1040,7 @@ void ggml_metal_graph_compute(
|
|||||||
const float scale = ((float *) dst->op_params)[0];
|
const float scale = ((float *) dst->op_params)[0];
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
if (id_src1) {
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
||||||
} else {
|
|
||||||
[encoder setBuffer:nil offset:0 atIndex:1];
|
|
||||||
}
|
|
||||||
|
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||||
|
@ -3705,8 +3705,8 @@ static struct ggml_tensor * llm_build_kqv(
|
|||||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||||
cb(kq, "kq", il);
|
cb(kq, "kq", il);
|
||||||
|
|
||||||
// TODO: !!!!!!!!!
|
|
||||||
if (max_alibi_bias > 0.0f) {
|
if (max_alibi_bias > 0.0f) {
|
||||||
|
// temporary branch until we figure out how to handle ggml_alibi through ggml_add
|
||||||
kq = ggml_scale(ctx, kq, kq_scale);
|
kq = ggml_scale(ctx, kq, kq_scale);
|
||||||
cb(kq, "kq_scaled", il);
|
cb(kq, "kq_scaled", il);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user