Scale the workgroup count down to allow correct generation for falcon with

AMD radeon cards with lower workgroup count limit

Partially fixes #1581
This commit is contained in:
Adam Treat 2023-10-27 18:32:29 -04:00 committed by cebtenzzre
parent 89b71278ff
commit e006d377dd
7 changed files with 27 additions and 14 deletions

View File

@ -1356,7 +1356,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{
if (ggml_nelements(src1) == ne10) {
// src1 is a row
ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
} else {
ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
}
@ -1365,7 +1365,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{
if (ggml_nelements(src1) == ne10) {
// src1 is a row
ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
} else {
ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
}
@ -1373,7 +1373,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
case GGML_OP_SCALE:
{
const float scale = *(const float *) src1->data;
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8, scale);
} break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) {
@ -1387,7 +1387,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
} break;
case GGML_UNARY_OP_GELU:
{
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8);
} break;
default:
{

View File

@ -24,7 +24,10 @@ layout(push_constant) uniform PushConstants {
} pcs;
void main() {
const uint i = gl_WorkGroupID.x;
const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
}
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
}
}

View File

@ -20,9 +20,9 @@ layout(push_constant) uniform PushConstants {
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 4;
const uint baseIndex = gl_WorkGroupID.x * 8;
for (uint x = 0; x < 4; x++) {
for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y)));

View File

@ -24,7 +24,10 @@ layout(push_constant) uniform PushConstants {
} pcs;
void main() {
const uint i = gl_WorkGroupID.x;
const uint baseIndex = gl_WorkGroupID.x * 4;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i % pcs.row) + pcs.inBOff];
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i % pcs.row) + pcs.inBOff];
}
}

View File

@ -22,7 +22,10 @@ layout(push_constant) uniform PushConstants {
} pcs;
void main() {
const uint i = gl_WorkGroupID.x;
const uint baseIndex = gl_WorkGroupID.x * 8;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}
}

View File

@ -18,8 +18,8 @@ layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
} pcs;
void main() {
void main() {
const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {

View File

@ -387,6 +387,10 @@ Algorithm::recordDispatch(const vk::CommandBuffer& commandBuffer)
void
Algorithm::setWorkgroup(const Workgroup& workgroup, uint32_t minSize)
{
if (workgroup[0] > 65535) {
fprintf(stderr, "workgroup size is %d\n", workgroup[0]);
fflush(stderr);
}
KP_LOG_INFO("Kompute OpAlgoCreate setting dispatch size");