mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 22:08:55 +01:00
cuda : update supports_op for matrix multiplication (#8245)
This commit is contained in:
parent
a9f3b10215
commit
0e0590adab
@ -2711,27 +2711,40 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
struct ggml_tensor * a;
|
struct ggml_tensor * a = op->src[0];
|
||||||
struct ggml_tensor * b;
|
|
||||||
if (op->op == GGML_OP_MUL_MAT) {
|
if (op->op == GGML_OP_MUL_MAT) {
|
||||||
a = op->src[0];
|
struct ggml_tensor * b = op->src[1];
|
||||||
b = op->src[1];
|
if (a->ne[3] != b->ne[3]) {
|
||||||
} else {
|
|
||||||
a = op->src[2];
|
|
||||||
b = op->src[1];
|
|
||||||
}
|
|
||||||
if (a->ne[3] != b->ne[3]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
ggml_type a_type = a->type;
|
|
||||||
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
|
|
||||||
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
|
|
||||||
a_type == GGML_TYPE_IQ1_M || a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
|
|
||||||
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
switch (a->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
case GGML_TYPE_Q2_K:
|
||||||
|
case GGML_TYPE_Q3_K:
|
||||||
|
case GGML_TYPE_Q4_K:
|
||||||
|
case GGML_TYPE_Q5_K:
|
||||||
|
case GGML_TYPE_Q6_K:
|
||||||
|
case GGML_TYPE_Q8_K:
|
||||||
|
case GGML_TYPE_IQ1_M:
|
||||||
|
case GGML_TYPE_IQ1_S:
|
||||||
|
case GGML_TYPE_IQ2_S:
|
||||||
|
case GGML_TYPE_IQ2_XS:
|
||||||
|
case GGML_TYPE_IQ2_XXS:
|
||||||
|
case GGML_TYPE_IQ3_S:
|
||||||
|
case GGML_TYPE_IQ3_XXS:
|
||||||
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
case GGML_TYPE_IQ4_XS:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
{
|
{
|
||||||
|
@ -2052,6 +2052,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
|
GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
|
||||||
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
|
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
|
||||||
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
|
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
|
||||||
|
GGML_TYPE_BF16,
|
||||||
};
|
};
|
||||||
|
|
||||||
// unary ops
|
// unary ops
|
||||||
|
Loading…
Reference in New Issue
Block a user