/**
 * Copyright (c) 2023 Nomic, Inc. All rights reserved.
 *
 * This software is licensed under the terms of the Software for Open Models
 * License (SOM), version 1.0, as detailed in the LICENSE_SOM.txt file. A copy
 * of this license should accompany this software. Except as expressly granted
 * in the SOM license, all rights are reserved by Nomic, Inc.
 */

#version 450

#include "common.comp"

#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable

// device subgroup size
// The workgroup x-size comes from specialization constant 0; per the note
// above, the host is expected to set it to the device subgroup size.
layout (local_size_x_id = 0) in;

// Storage buffers: two read-only float inputs and one write-only float output.
layout(binding = 0) readonly buffer tensorInA { float inA[]; };
layout(binding = 1) readonly buffer tensorInB { float inB[]; };
layout(binding = 2) writeonly buffer tensorOut { float out_[]; };

// Push constants describing where and how to index the three buffers.
// NOTE(review): the field order presumably mirrors a host-side struct and
// must not be rearranged -- confirm against the dispatching code.
layout(push_constant) uniform parameter {
  uint inAOff;  // added to the inA index after the stride sum is divided by 4 (i.e. counted in floats)
  uint inBOff;  // same, for inB
  uint outOff;  // added directly to the out_ index (counted in floats)
  int ne00;     // upper bound of the dot-product loop: elements read from each input row
  int ne01;     // not read by this shader
  int ne02;     // compared/divided with ne12 to map gl_WorkGroupID.z onto an inA batch index
  int ne11;     // not read by this shader
  int ne12;     // compared/divided with ne02 to map gl_WorkGroupID.z onto an inB batch index
  uint nb01;    // inA stride multiplied by gl_WorkGroupID.x; divided by 4 before indexing, so presumably bytes
  uint nb02;    // inA stride multiplied by the inA batch index; presumably bytes
  uint nb11;    // inB stride multiplied by gl_WorkGroupID.y; presumably bytes
  uint nb12;    // inB stride multiplied by the inB batch index; presumably bytes
  uint nb1;     // out_ stride for gl_WorkGroupID.y; divided by 4 before use
  uint nb2;     // out_ stride for gl_WorkGroupID.z; divided by 4 before use
}
pcs;


// Each workgroup produces one element of out_: the dot product of ne00
// consecutive floats from inA with ne00 consecutive floats from inB.
// Invocations take interleaved slices of the row and their partial sums are
// combined with a subgroup-wide add.
void main() {
  const uint row_a = gl_WorkGroupID.x;
  const uint row_b = gl_WorkGroupID.y;
  const uint batch = gl_WorkGroupID.z;

  // When one input has fewer batches than the other (by an integer ratio),
  // several consecutive workgroup z-indices share the same batch index.
  uint batch_a = batch;
  if (pcs.ne12 > pcs.ne02) {
    batch_a = batch / (pcs.ne12 / pcs.ne02);
  }
  uint batch_b = batch;
  if (pcs.ne02 > pcs.ne12) {
    batch_b = batch / (pcs.ne02 / pcs.ne12);
  }

  // Strides are divided by 4 to turn them into float indices; the buffer
  // offsets are added afterwards and are therefore already in floats.
  const uint base_a = (row_a*pcs.nb01 + batch_a*pcs.nb02) / 4 + pcs.inAOff;
  const uint base_b = (row_b*pcs.nb11 + batch_b*pcs.nb12) / 4 + pcs.inBOff;

  // Invocation k handles elements k, k + gl_SubgroupSize, k + 2*gl_SubgroupSize, ...
  float partial = 0.0f;
  uint k = gl_SubgroupInvocationID;
  while (k < pcs.ne00) {
    partial += inA[base_a + k] * inB[base_b + k];
    k += gl_SubgroupSize;
  }

  // Reduce across the subgroup, then let a single elected invocation store
  // the result.
  const float total = subgroupAdd(partial);
  if (subgroupElect()) {
    const uint dst = batch*(pcs.nb2/4) + row_b*(pcs.nb1/4) + row_a + pcs.outOff;
    out_[dst] = total;
  }
}