mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00
cuda : support non-contiguous src1 in get_rows
This commit is contained in:
parent
2e4db48291
commit
62b95f93d0
132
ggml-cuda.cu
132
ggml-cuda.cu
@ -1686,31 +1686,39 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
|
||||
}
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
|
||||
const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
|
||||
const int row = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
static __global__ void k_get_rows(
|
||||
const void * src0, const int32_t * src1, dst_t * dst,
|
||||
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
|
||||
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
|
||||
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
|
||||
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
|
||||
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
|
||||
|
||||
if (col >= ncols) {
|
||||
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
|
||||
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
|
||||
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
|
||||
|
||||
if (i00 >= ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int r = y[row];
|
||||
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||
|
||||
// copy x[r*ncols + col] to dst[row*ncols + col]
|
||||
const int xi = r*ncols + col;
|
||||
const int di = row*ncols + col;
|
||||
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||
const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
|
||||
const int ib = xi/qk; // block index
|
||||
const int iqs = (xi%qk)/qr; // quant index
|
||||
const int iybs = di - di%qk; // y block start index
|
||||
const int ib = i00/qk; // block index
|
||||
const int iqs = (i00%qk)/qr; // quant index
|
||||
const int iybs = i00 - i00%qk; // dst block start index
|
||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
// dequantize
|
||||
dfloat2 v;
|
||||
dequantize_kernel(x, ib, iqs, v);
|
||||
dequantize_kernel(src0_row, ib, iqs, v);
|
||||
|
||||
dst[iybs + iqs + 0] = v.x;
|
||||
dst[iybs + iqs + y_offset] = v.y;
|
||||
dst_row[iybs + iqs + 0] = v.x;
|
||||
dst_row[iybs + iqs + y_offset] = v.y;
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
@ -5055,11 +5063,35 @@ static __global__ void im2col_f32_f16(
|
||||
}
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dq>
|
||||
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
|
||||
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
||||
const dim3 block_nums(block_num_x, nrows, 1);
|
||||
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
|
||||
const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
||||
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
|
||||
|
||||
// strides in elements
|
||||
//const size_t s0 = nb0 / ggml_element_size(dst);
|
||||
const size_t s1 = nb1 / ggml_element_size(dst);
|
||||
const size_t s2 = nb2 / ggml_element_size(dst);
|
||||
const size_t s3 = nb3 / ggml_element_size(dst);
|
||||
|
||||
const size_t s10 = nb10 / ggml_element_size(src1);
|
||||
const size_t s11 = nb11 / ggml_element_size(src1);
|
||||
const size_t s12 = nb12 / ggml_element_size(src1);
|
||||
//const size_t s13 = nb13 / ggml_element_size(src1);
|
||||
|
||||
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10, ne11,*/ ne12, /*ne13,*/
|
||||
/* s0,*/ s1, s2, s3,
|
||||
/* nb00,*/ nb01, nb02, nb03,
|
||||
s10, s11, s12/*, s13*/);
|
||||
|
||||
(void) dst;
|
||||
}
|
||||
|
||||
template<float (*bin_op)(const float, const float)>
|
||||
@ -5071,7 +5103,6 @@ struct bin_bcast_cuda {
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
|
||||
int nr0 = ne10/ne0;
|
||||
int nr1 = ne11/ne1;
|
||||
int nr2 = ne12/ne2;
|
||||
@ -5119,26 +5150,28 @@ struct bin_bcast_cuda {
|
||||
int64_t ne12 = cne1[2];
|
||||
int64_t ne13 = cne1[3];
|
||||
|
||||
//size_t nb0 = cnb0[0];
|
||||
size_t nb0 = cnb0[0];
|
||||
size_t nb1 = cnb0[1];
|
||||
size_t nb2 = cnb0[2];
|
||||
size_t nb3 = cnb0[3];
|
||||
|
||||
//size_t nb10 = cnb1[0];
|
||||
size_t nb10 = cnb1[0];
|
||||
size_t nb11 = cnb1[1];
|
||||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
|
||||
//size_t s0 = nb0 / sizeof(src1_t);
|
||||
size_t s0 = nb0 / sizeof(src1_t);
|
||||
size_t s1 = nb1 / sizeof(src1_t);
|
||||
size_t s2 = nb2 / sizeof(src1_t);
|
||||
size_t s3 = nb3 / sizeof(src1_t);
|
||||
|
||||
//size_t s10 = nb10 / sizeof(src1_t);
|
||||
size_t s10 = nb10 / sizeof(src1_t);
|
||||
size_t s11 = nb11 / sizeof(src1_t);
|
||||
size_t s12 = nb12 / sizeof(src1_t);
|
||||
size_t s13 = nb13 / sizeof(src1_t);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
@ -6449,36 +6482,34 @@ static void ggml_cuda_op_get_rows(
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
const int ncols = src0->ne[0];
|
||||
const int nrows = ggml_nelements(src1);
|
||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
|
||||
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
|
||||
|
||||
const int32_t * src1_i32 = (const int32_t *) src1_d;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16:
|
||||
get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
get_rows_cuda<1, 1, convert_f16>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
get_rows_cuda<1, 1, convert_f32>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
break;
|
||||
default:
|
||||
// TODO: k-quants
|
||||
@ -8286,11 +8317,8 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
|
||||
|
||||
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
|
||||
|
||||
if (src1->backend == GGML_BACKEND_GPU) {
|
||||
src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
|
||||
} else {
|
||||
src1_row.data = (char *) src1->data + i01*src1->nb[1];
|
||||
}
|
||||
src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
|
||||
src1_row.data = (char *) src1->data + i01*src1->nb[1];
|
||||
|
||||
dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
|
||||
|
||||
@ -8707,9 +8735,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
func = ggml_cuda_repeat;
|
||||
break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
if (ggml_is_contiguous(tensor->src[1])) {
|
||||
func = ggml_cuda_get_rows;
|
||||
}
|
||||
func = ggml_cuda_get_rows;
|
||||
break;
|
||||
case GGML_OP_DUP:
|
||||
func = ggml_cuda_dup;
|
||||
@ -9215,6 +9241,21 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
|
||||
}
|
||||
return true;
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
@ -9222,7 +9263,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_MUL:
|
||||
@ -9298,7 +9338,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
|
||||
UNUSED(params);
|
||||
}
|
||||
|
||||
extern "C" int ggml_backend_cuda_reg_devices() {
|
||||
extern "C" int ggml_backend_cuda_reg_devices();
|
||||
|
||||
int ggml_backend_cuda_reg_devices() {
|
||||
int device_count = ggml_cuda_get_device_count();
|
||||
//int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
|
||||
for (int i = 0; i < device_count; i++) {
|
||||
|
26
llama.cpp
26
llama.cpp
@ -4254,12 +4254,13 @@ struct llm_build_context {
|
||||
|
||||
// select experts
|
||||
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
|
||||
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
||||
|
||||
ggml_tensor * weights = ggml_get_rows(ctx0,
|
||||
ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
|
||||
ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
|
||||
cb(weights, "ffn_moe_weights", il);
|
||||
|
||||
weights = ggml_reshape_2d(ctx0, weights,
|
||||
n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
|
||||
weights = ggml_reshape_2d(ctx0, weights, n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
|
||||
|
||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
|
||||
cb(weights_sum, "ffn_moe_weights_sum", il);
|
||||
@ -4268,7 +4269,7 @@ struct llm_build_context {
|
||||
cb(weights, "ffn_moe_weights_norm", il);
|
||||
|
||||
// compute expert outputs
|
||||
ggml_tensor * moe_out;
|
||||
ggml_tensor * moe_out = nullptr;
|
||||
|
||||
for (int i = 0; i < n_experts_per_tok; ++i) {
|
||||
ggml_tensor * cur_expert;
|
||||
@ -4279,19 +4280,19 @@ struct llm_build_context {
|
||||
ggml_tensor ** ffn_down_exp = (ggml_tensor **) model.layers[il].ffn_down_exp;
|
||||
|
||||
ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, ffn_up_exp, n_experts, selected_experts, i, cur);
|
||||
cb(cur_up, "ffn_up", il);
|
||||
cb(cur_up, "ffn_moe_up", il);
|
||||
|
||||
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, ffn_gate_exp, n_experts, selected_experts, i, cur);
|
||||
cb(cur_gate, "ffn_gate", il);
|
||||
cb(cur_gate, "ffn_moe_gate", il);
|
||||
|
||||
cur_gate = ggml_silu(ctx0, cur_gate);
|
||||
cb(cur_gate, "ffn_silu", il);
|
||||
cb(cur_gate, "ffn_moe_silu", il);
|
||||
|
||||
cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
|
||||
cb(cur_expert, "ffn_gate_par", il);
|
||||
cb(cur_expert, "ffn_moe_gate_par", il);
|
||||
|
||||
cur_expert = ggml_mul_mat_id(ctx0, ffn_down_exp, n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
|
||||
cb(cur_expert, "ffn_down", il);
|
||||
cb(cur_expert, "ffn_moe_down", il);
|
||||
|
||||
cur_expert = ggml_mul(ctx0, cur_expert,
|
||||
ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
|
||||
@ -5562,10 +5563,15 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
|
||||
|
||||
{ "ffn_moe_logits", OFFLOAD_FUNC },
|
||||
{ "ffn_moe_probs", OFFLOAD_FUNC },
|
||||
{ "ffn_moe_weights", OFFLOAD_FUNC_NOP },
|
||||
{ "ffn_moe_argsort", OFFLOAD_FUNC },
|
||||
{ "ffn_moe_weights", OFFLOAD_FUNC },
|
||||
{ "ffn_moe_weights_sum", OFFLOAD_FUNC },
|
||||
{ "ffn_moe_weights_norm", OFFLOAD_FUNC },
|
||||
{ "ffn_moe_weighted", OFFLOAD_FUNC },
|
||||
{ "ffn_moe_up", OFFLOAD_FUNC },
|
||||
{ "ffn_moe_gate", OFFLOAD_FUNC },
|
||||
{ "ffn_moe_gate_par", OFFLOAD_FUNC },
|
||||
{ "ffn_moe_down", OFFLOAD_FUNC },
|
||||
{ "ffn_moe_out", OFFLOAD_FUNC },
|
||||
|
||||
{ "l_out", OFFLOAD_FUNC },
|
||||
|
@ -489,17 +489,21 @@ struct test_get_rows : public test_case {
|
||||
const int m; // rows
|
||||
const int r; // rows to get
|
||||
const int b; // batch size
|
||||
const bool v; // view (non-contiguous src1)
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, n, m, r);
|
||||
return VARS_TO_STR6(type, n, m, r, b, v);
|
||||
}
|
||||
|
||||
test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1)
|
||||
: type(type), n(n), m(m), r(r), b(b) {}
|
||||
test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
|
||||
: type(type), n(n), m(m), r(r), b(b), v(v) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
|
||||
ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
|
||||
if (v) {
|
||||
rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
|
||||
}
|
||||
ggml_tensor * out = ggml_get_rows(ctx, in, rows);
|
||||
return out;
|
||||
}
|
||||
@ -507,6 +511,7 @@ struct test_get_rows : public test_case {
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
if (ggml_is_view_op(t->op)) { continue; }
|
||||
// rows
|
||||
std::vector<int> data(r*b);
|
||||
for (int i = 0; i < r*b; i++) {
|
||||
@ -773,9 +778,10 @@ struct test_mul_mat_id : public test_case {
|
||||
const int64_t m;
|
||||
const int64_t n;
|
||||
const int64_t k;
|
||||
const bool v; // view (non-contiguous ids)
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR7(type_a, type_b, n_mats, id, m, n, k);
|
||||
return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
@ -793,9 +799,9 @@ struct test_mul_mat_id : public test_case {
|
||||
|
||||
test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
|
||||
int n_mats = 2, int id = 0,
|
||||
int64_t m = 32, int64_t n = 32, int64_t k = 32)
|
||||
int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false)
|
||||
: type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
|
||||
m(m), n(n), k(k) {}
|
||||
m(m), n(n), k(k), v(v) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
|
||||
@ -805,8 +811,11 @@ struct test_mul_mat_id : public test_case {
|
||||
mats.push_back(a);
|
||||
}
|
||||
ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
|
||||
if (v) {
|
||||
ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0);
|
||||
}
|
||||
ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
|
||||
ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), n_mats, ids, id, b);
|
||||
ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), n_mats, ids, v ? id/2 : id, b);
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -815,11 +824,12 @@ struct test_mul_mat_id : public test_case {
|
||||
std::default_random_engine rng(rd());
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
if (ggml_is_view_op(t->op)) { continue; }
|
||||
// ids
|
||||
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||
std::vector<int32_t> data(t->ne[0]);
|
||||
for (int i = 0; i < t->ne[0]; i++) {
|
||||
data[i] = i;
|
||||
data[i] = i % n_mats;
|
||||
}
|
||||
std::shuffle(data.begin(), data.end(), rng);
|
||||
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
|
||||
@ -1120,14 +1130,27 @@ enum test_mode {
|
||||
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
|
||||
std::vector<std::unique_ptr<test_case>> test_cases;
|
||||
|
||||
const ggml_type all_types[] = {
|
||||
GGML_TYPE_F32, GGML_TYPE_F16,
|
||||
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
|
||||
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
|
||||
GGML_TYPE_Q6_K
|
||||
};
|
||||
|
||||
// unary ops
|
||||
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
|
||||
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
|
||||
}
|
||||
|
||||
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_get_rows(type, 10, 5, 3, 7));
|
||||
test_cases.emplace_back(new test_get_rows(type, 16, 5, 3, 7));
|
||||
for (ggml_type type : all_types) {
|
||||
for (int b : {1, 7}) {
|
||||
for (bool v : {false, true}) {
|
||||
test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, v));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
|
||||
@ -1183,16 +1206,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
|
||||
}
|
||||
|
||||
const ggml_type all_types[] = {
|
||||
GGML_TYPE_F32, GGML_TYPE_F16,
|
||||
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
|
||||
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
|
||||
GGML_TYPE_Q6_K
|
||||
};
|
||||
|
||||
for (ggml_type type_a : all_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
|
||||
// FIXME: CPU crashes on f16xf16
|
||||
@ -1216,9 +1229,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
|
||||
for (ggml_type type_a : all_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
|
||||
for (int n_mats : {1, 2, 4}) {
|
||||
for (int n_mats : {2, 4, 8}) {
|
||||
for (int id = 0; id < n_mats; id++) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256));
|
||||
for (bool v : {false, true}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user