diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c57a031e4..15fc6154f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -9064,7 +9064,7 @@ static void ggml_cuda_op_soft_max( const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; + const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded! float scale = 1.0f; memcpy(&scale, dst->op_params, sizeof(float));