mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 14:50:51 +01:00
Scale the workgroup count down to allow correct generation for falcon with
AMD radeon cards with lower workgroup count limit Partially fixes #1581
This commit is contained in:
parent
89b71278ff
commit
e006d377dd
@ -1356,7 +1356,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
{
|
||||
if (ggml_nelements(src1) == ne10) {
|
||||
// src1 is a row
|
||||
ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
|
||||
ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
|
||||
} else {
|
||||
ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
|
||||
}
|
||||
@ -1365,7 +1365,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
{
|
||||
if (ggml_nelements(src1) == ne10) {
|
||||
// src1 is a row
|
||||
ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
|
||||
ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
|
||||
} else {
|
||||
ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
|
||||
}
|
||||
@ -1373,7 +1373,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
case GGML_OP_SCALE:
|
||||
{
|
||||
const float scale = *(const float *) src1->data;
|
||||
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
|
||||
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8, scale);
|
||||
} break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||
@ -1387,7 +1387,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
} break;
|
||||
case GGML_UNARY_OP_GELU:
|
||||
{
|
||||
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
|
||||
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
@ -24,7 +24,10 @@ layout(push_constant) uniform PushConstants {
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x;
|
||||
const uint baseIndex = gl_WorkGroupID.x * 4;
|
||||
|
||||
for (uint x = 0; x < 4; x++) {
|
||||
const uint i = baseIndex + x;
|
||||
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
|
||||
}
|
||||
}
|
||||
|
@ -20,9 +20,9 @@ layout(push_constant) uniform PushConstants {
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
const uint baseIndex = gl_WorkGroupID.x * 4;
|
||||
const uint baseIndex = gl_WorkGroupID.x * 8;
|
||||
|
||||
for (uint x = 0; x < 4; x++) {
|
||||
for (uint x = 0; x < 8; x++) {
|
||||
const uint i = baseIndex + x;
|
||||
const float y = in_[i + pcs.inOff];
|
||||
out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y)));
|
||||
|
@ -24,7 +24,10 @@ layout(push_constant) uniform PushConstants {
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x;
|
||||
const uint baseIndex = gl_WorkGroupID.x * 4;
|
||||
|
||||
for (uint x = 0; x < 4; x++) {
|
||||
const uint i = baseIndex + x;
|
||||
out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i % pcs.row) + pcs.inBOff];
|
||||
}
|
||||
}
|
@ -22,7 +22,10 @@ layout(push_constant) uniform PushConstants {
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x;
|
||||
const uint baseIndex = gl_WorkGroupID.x * 8;
|
||||
|
||||
for (uint x = 0; x < 8; x++) {
|
||||
const uint i = baseIndex + x;
|
||||
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
|
||||
}
|
||||
}
|
@ -18,8 +18,8 @@ layout(push_constant) uniform PushConstants {
|
||||
uint inOff;
|
||||
uint outOff;
|
||||
} pcs;
|
||||
void main() {
|
||||
|
||||
void main() {
|
||||
const uint baseIndex = gl_WorkGroupID.x * 4;
|
||||
|
||||
for (uint x = 0; x < 4; x++) {
|
||||
|
@ -387,6 +387,10 @@ Algorithm::recordDispatch(const vk::CommandBuffer& commandBuffer)
|
||||
void
|
||||
Algorithm::setWorkgroup(const Workgroup& workgroup, uint32_t minSize)
|
||||
{
|
||||
if (workgroup[0] > 65535) {
|
||||
fprintf(stderr, "workgroup size is %d\n", workgroup[0]);
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
KP_LOG_INFO("Kompute OpAlgoCreate setting dispatch size");
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user