mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-28 12:57:03 +01:00
Lower the workgroup count for some shaders by providing a loop that processes
four floats at a time.
This commit is contained in:
parent
752f7ebd61
commit
8d9efbf97a
@ -1358,7 +1358,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
// src1 is a row
|
||||
ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
|
||||
} else {
|
||||
ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst));
|
||||
ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_MUL:
|
||||
@ -1367,7 +1367,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
// src1 is a row
|
||||
ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
|
||||
} else {
|
||||
ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst));
|
||||
ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SCALE:
|
||||
@ -1379,15 +1379,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
{
|
||||
ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
|
||||
ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
|
||||
} break;
|
||||
case GGML_UNARY_OP_RELU:
|
||||
{
|
||||
ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
|
||||
ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
|
||||
} break;
|
||||
case GGML_UNARY_OP_GELU:
|
||||
{
|
||||
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
|
||||
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x;
|
||||
const uint baseIndex = gl_WorkGroupID.x * 4;
|
||||
|
||||
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i) + pcs.inBOff];
|
||||
for (uint x = 0; x < 4; x++) {
|
||||
const uint i = baseIndex + x;
|
||||
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[i + pcs.inBOff];
|
||||
}
|
||||
}
|
@ -20,8 +20,11 @@ layout(push_constant) uniform PushConstants {
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x;
|
||||
const float x = in_[i + pcs.inOff];
|
||||
const uint baseIndex = gl_WorkGroupID.x * 4;
|
||||
|
||||
out_[i + pcs.outOff] = 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
|
||||
for (uint x = 0; x < 4; x++) {
|
||||
const uint i = baseIndex + x;
|
||||
const float y = in_[i + pcs.inOff];
|
||||
out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y)));
|
||||
}
|
||||
}
|
||||
|
@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x;
|
||||
const uint baseIndex = gl_WorkGroupID.x * 4;
|
||||
|
||||
for (uint x = 0; x < 4; x++) {
|
||||
const uint i = baseIndex + x;
|
||||
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff];
|
||||
}
|
||||
}
|
@ -20,7 +20,10 @@ layout(push_constant) uniform PushConstants {
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x;
|
||||
const uint baseIndex = gl_WorkGroupID.x * 4;
|
||||
|
||||
for (uint x = 0; x < 4; x++) {
|
||||
const uint i = baseIndex + x;
|
||||
out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
|
||||
}
|
||||
}
|
||||
|
@ -19,8 +19,12 @@ layout(push_constant) uniform PushConstants {
|
||||
uint outOff;
|
||||
} pcs;
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x;
|
||||
const float x = in_[i + pcs.inOff];
|
||||
|
||||
out_[i + pcs.outOff] = x / (1.0 + exp(-x));
|
||||
const uint baseIndex = gl_WorkGroupID.x * 4;
|
||||
|
||||
for (uint x = 0; x < 4; x++) {
|
||||
const uint i = baseIndex + x;
|
||||
const float y = in_[i + pcs.inOff];
|
||||
out_[i + pcs.outOff] = y / (1.0 + exp(-y));
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user