Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-28 12:57:03 +01:00)
vulkan : assert various kernel requirements
This commit is contained in:
parent: f194e1b6a6
commit: a934b2cb8a
@@ -1416,27 +1416,34 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
{
|
{
|
||||||
const float scale = *(const float *) src1->data;
|
const float scale = *(const float *) src1->data;
|
||||||
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8, scale);
|
int64_t n = ggml_nelements(dst);
|
||||||
|
GGML_ASSERT(n % 8 == 0);
|
||||||
|
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, n/8, scale);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
{
|
||||||
case GGML_UNARY_OP_SILU:
|
int64_t n = ggml_nelements(dst);
|
||||||
{
|
GGML_ASSERT(n % 4 == 0);
|
||||||
ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||||
} break;
|
case GGML_UNARY_OP_SILU:
|
||||||
case GGML_UNARY_OP_RELU:
|
{
|
||||||
{
|
ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
|
||||||
ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
|
} break;
|
||||||
} break;
|
case GGML_UNARY_OP_RELU:
|
||||||
case GGML_UNARY_OP_GELU:
|
{
|
||||||
{
|
ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
|
||||||
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8);
|
} break;
|
||||||
} break;
|
case GGML_UNARY_OP_GELU:
|
||||||
default:
|
{
|
||||||
{
|
GGML_ASSERT(n % 8 == 0);
|
||||||
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8);
|
||||||
GGML_ASSERT(false);
|
} break;
|
||||||
}
|
default:
|
||||||
|
{
|
||||||
|
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
@@ -1455,6 +1462,8 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(ne00 % 4 == 0);
|
||||||
|
|
||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, dst->op_params, sizeof(float));
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
|
ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
|
||||||
|
Loading…
Reference in New Issue
Block a user