Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-27 06:39:25 +01:00)
sycl : Fixes to broken builds and test-backend-ops (#10257)
* Fixes broken build for the SYCL CUDA backend caused by a non-explicit gemm call in outprod (merged in with RWKV6 in "Optimize RWKV6 Operator Naming and Implement Multi-core CPU/SYCL Acceleration", #10133)
* Marks permuted MUL_MAT as unsupported to be able to run test-backend-ops
* Fixes asserts in norm to fix debug builds
parent 80dd7ff22f
commit 2e82ffa4af
@@ -4350,6 +4350,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->op == GGML_OP_MUL_MAT) {
                 a = op->src[0];
                 b = op->src[1];
+                if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
+                    // TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
+                    return false;
+                }
             } else {
                 a = op->src[2];
                 b = op->src[1];
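For reference, the permuted case that test-backend-ops exercises can be reproduced with the public ggml API roughly as follows. This is a minimal sketch: the tensor shapes are illustrative, not taken from the test suite.

// Sketch: build a mul_mat whose first operand is permuted -- the case the SYCL
// backend now reports as unsupported. Shapes are illustrative only.
#include <cstdio>
#include "ggml.h"

int main() {
    struct ggml_init_params params = { /*mem_size*/ 16u * 1024 * 1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // src0: [64, 32, 32], src1: [64, 16, 32] -- dim 0 (the reduction dim) matches
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 32, 32);
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 16, 32);

    // swap dims 1 and 2 of a: dim 0 is untouched, so mul_mat stays well-formed,
    // but a_perm is no longer laid out contiguously, so ggml_is_permuted() is true
    struct ggml_tensor * a_perm = ggml_permute(ctx, a, 0, 2, 1, 3);
    struct ggml_tensor * out    = ggml_mul_mat(ctx, a_perm, b);

    std::printf("src0 permuted: %d\n", ggml_is_permuted(out->src[0])); // 1 -> supports_op returns false
    ggml_free(ctx);
    return 0;
}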
@@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep

     const int nthreads = item_ct1.get_local_range(2);
     const int nwarps = nthreads / WARP_SIZE;
-    assert(nwarps % WARP_SIZE == 0);
     sycl::float2 mean_var = sycl::float2(0.f, 0.f);

     for (int col = tid; col < ncols; col += block_size) {
@@ -55,7 +54,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
     int end = start + group_size;
     const int nthreads = item_ct1.get_local_range(2);
     const int nwarps = nthreads / WARP_SIZE;
-    assert(nwarps % WARP_SIZE == 0);
     start += item_ct1.get_local_id(2);
     int nreduce = nwarps / WARP_SIZE;

@@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
     const int tid = item_ct1.get_local_id(2);
     const int nthreads = item_ct1.get_local_range(2);
     const int nwarps = nthreads / WARP_SIZE;
-    assert(nwarps % WARP_SIZE == 0);
     float tmp = 0.0f; // partial sum for thread in warp

     for (int col = tid; col < ncols; col += block_size) {
@@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
         const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
         DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -244,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
         const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
         DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -290,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
     }
    else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
         const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
         DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
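Why the in-kernel assert broke debug builds: with WARP_SIZE == 32 and a single-warp launch (as in the small-column branch of these launchers), nthreads == WARP_SIZE, so nwarps == 1 and nwarps % WARP_SIZE != 0, firing the assert as soon as asserts are enabled. The relocated check is only evaluated in the launchers' max-work-group branch, where it holds. A small sketch of the arithmetic, assuming WARP_SIZE == 32 and illustrative work-group sizes:

// Sketch: compare the removed kernel-side condition with the relocated
// launcher-side condition for two illustrative work-group sizes.
#include <cstdio>

constexpr int WARP_SIZE = 32; // assumed, as in the SYCL backend

static void check(int nthreads) {
    const int nwarps = nthreads / WARP_SIZE;
    std::printf("nthreads=%4d  nwarps=%2d  old kernel assert (nwarps %% WARP_SIZE == 0): %s  "
                "new launcher assert (nthreads %% (WARP_SIZE*WARP_SIZE) == 0): %s\n",
                nthreads, nwarps,
                (nwarps % WARP_SIZE == 0) ? "holds" : "FAILS",
                (nthreads % (WARP_SIZE * WARP_SIZE) == 0) ? "holds" : "fails");
}

int main() {
    check(WARP_SIZE); // single-warp launch: the old in-kernel assert fired here in debug builds
    check(1024);      // max work-group launch: holds; the only place the new assert is evaluated
    return 0;
}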
@@ -1,4 +1,5 @@
 #include <sycl/sycl.hpp>
+#include <oneapi/mkl.hpp>
 #include "outprod.hpp"


@@ -39,7 +40,7 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr

     try {
         // Perform matrix multiplication using oneMKL GEMM
-        oneapi::mkl::blas::gemm(*stream,
+        oneapi::mkl::blas::column_major::gemm(*stream,
             oneapi::mkl::transpose::nontrans, src1_op,
             ne0, ne1, ne01,
             alpha,
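The outprod fix names the column_major layout namespace explicitly rather than relying on the unqualified oneapi::mkl::blas::gemm alias, which, per the commit message, did not resolve when building the SYCL backend for CUDA. A standalone call with the explicit namespace looks roughly like the sketch below; the USM buffers, shapes, and leading dimensions are illustrative, not the backend's.

// Sketch: explicit column_major oneMKL GEMM on USM memory (C = 1.0 * A * B + 0.0 * C).
#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>
#include <cstdint>

int main() {
    sycl::queue q; // default device

    const std::int64_t m = 4, n = 3, k = 2; // illustrative sizes
    float * A = sycl::malloc_shared<float>(m * k, q); // m x k, column-major
    float * B = sycl::malloc_shared<float>(k * n, q); // k x n, column-major
    float * C = sycl::malloc_shared<float>(m * n, q); // m x n, column-major
    for (std::int64_t i = 0; i < m * k; ++i) A[i] = 1.0f;
    for (std::int64_t i = 0; i < k * n; ++i) B[i] = 2.0f;
    for (std::int64_t i = 0; i < m * n; ++i) C[i] = 0.0f;

    // Explicitly naming the layout namespace avoids depending on a default
    // oneapi::mkl::blas::gemm alias being available for the target backend.
    oneapi::mkl::blas::column_major::gemm(q,
        oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
        m, n, k,
        1.0f, A, /*lda=*/m,
              B, /*ldb=*/k,
        0.0f, C, /*ldc=*/m).wait();

    sycl::free(A, q); sycl::free(B, q); sycl::free(C, q);
    return 0;
}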