metal : simplify soft_max encoding

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-11-29 17:30:19 +02:00
parent 390a445906
commit 580fe2064c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 2 additions and 7 deletions

View File

@ -1040,12 +1040,7 @@ void ggml_metal_graph_compute(
const float scale = ((float *) dst->op_params)[0]; const float scale = ((float *) dst->op_params)[0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) {
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
} else {
[encoder setBuffer:nil offset:0 atIndex:1];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];

View File

@ -3705,8 +3705,8 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il); cb(kq, "kq", il);
// TODO: !!!!!!!!!
if (max_alibi_bias > 0.0f) { if (max_alibi_bias > 0.0f) {
// temporary branch until we figure out how to handle ggml_alibi through ggml_add
kq = ggml_scale(ctx, kq, kq_scale); kq = ggml_scale(ctx, kq, kq_scale);
cb(kq, "kq_scaled", il); cb(kq, "kq_scaled", il);