Scale the workgroup count down to allow correct generation for falcon with

AMD radeon cards with lower workgroup count limit

Partially fixes #1581
This commit is contained in:
Adam Treat 2023-10-27 18:32:29 -04:00 committed by cebtenzzre
parent 89b71278ff
commit e006d377dd
7 changed files with 27 additions and 14 deletions

View File

@ -1356,7 +1356,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{ {
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
// src1 is a row // src1 is a row
ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00); ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
} else { } else {
ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4); ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
} }
@ -1365,7 +1365,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{ {
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
// src1 is a row // src1 is a row
ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00); ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
} else { } else {
ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4); ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
} }
@ -1373,7 +1373,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
case GGML_OP_SCALE: case GGML_OP_SCALE:
{ {
const float scale = *(const float *) src1->data; const float scale = *(const float *) src1->data;
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale); ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8, scale);
} break; } break;
case GGML_OP_UNARY: case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) { switch (ggml_get_unary_op(gf->nodes[i])) {
@ -1387,7 +1387,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
} break; } break;
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:
{ {
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4); ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8);
} break; } break;
default: default:
{ {

View File

@ -24,7 +24,10 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x; const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff]; for (uint x = 0; x < 4; x++) {
} const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
}
}

View File

@ -20,9 +20,9 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint baseIndex = gl_WorkGroupID.x * 4; const uint baseIndex = gl_WorkGroupID.x * 8;
for (uint x = 0; x < 4; x++) { for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x; const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff]; const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y))); out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y)));

View File

@ -24,7 +24,10 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x; const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i % pcs.row) + pcs.inBOff]; for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i % pcs.row) + pcs.inBOff];
}
} }

View File

@ -22,7 +22,10 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
const uint i = gl_WorkGroupID.x; const uint baseIndex = gl_WorkGroupID.x * 8;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale; for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}
} }

View File

@ -18,8 +18,8 @@ layout(push_constant) uniform PushConstants {
uint inOff; uint inOff;
uint outOff; uint outOff;
} pcs; } pcs;
void main() {
void main() {
const uint baseIndex = gl_WorkGroupID.x * 4; const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) { for (uint x = 0; x < 4; x++) {

View File

@ -387,6 +387,10 @@ Algorithm::recordDispatch(const vk::CommandBuffer& commandBuffer)
void void
Algorithm::setWorkgroup(const Workgroup& workgroup, uint32_t minSize) Algorithm::setWorkgroup(const Workgroup& workgroup, uint32_t minSize)
{ {
if (workgroup[0] > 65535) {
fprintf(stderr, "workgroup size is %d\n", workgroup[0]);
fflush(stderr);
}
KP_LOG_INFO("Kompute OpAlgoCreate setting dispatch size"); KP_LOG_INFO("Kompute OpAlgoCreate setting dispatch size");