mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-28 04:47:04 +01:00
kompute : fix ggml_add kernel
This commit is contained in:
parent
610394fff8
commit
1453215165
@ -1467,7 +1467,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
|
||||
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
||||
// src1 is a row
|
||||
ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
|
||||
} else {
|
||||
|
@ -30,6 +30,7 @@ layout(push_constant) uniform PushConstants {
|
||||
int nb1;
|
||||
int nb2;
|
||||
int nb3;
|
||||
//int offs; // TODO: needed for GGML_OP_ACC, see metal code
|
||||
} pcs;
|
||||
|
||||
// general-purpose kernel for addition of two tensors
|
||||
@ -44,15 +45,14 @@ void main() {
|
||||
const uint i12 = i02 % pcs.ne12;
|
||||
const uint i11 = i01 % pcs.ne11;
|
||||
|
||||
uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + gl_SubgroupInvocationID.x*pcs.nb00) / 4);
|
||||
uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11 + gl_SubgroupInvocationID.x*pcs.nb10) / 4);
|
||||
uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1 + gl_SubgroupInvocationID.x*pcs.nb0 ) / 4);
|
||||
int offs = 0; // TMP (see above)
|
||||
|
||||
uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + offs) / 4);
|
||||
uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11 ) / 4);
|
||||
uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1 + offs) / 4);
|
||||
|
||||
for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
|
||||
out_[pcs.outOff + dst_off] = inA[pcs.inAOff + src0_off] + inB[pcs.inBOff + src1_off];
|
||||
|
||||
src0_off += gl_WorkGroupSize.x*pcs.ne00;
|
||||
src1_off += gl_WorkGroupSize.x*pcs.ne10;
|
||||
dst_off += gl_WorkGroupSize.x*pcs.ne0;
|
||||
const uint i10 = i0 % pcs.ne10;
|
||||
out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] + inB[pcs.inBOff + src1_off + i10];
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user