mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 14:50:51 +01:00
use op param epsilon for norms
This commit is contained in:
parent
3327d84a7f
commit
d5741c07a5
@ -810,12 +810,10 @@ void ggml_vk_norm_(const std::vector<uint32_t>& spirv, kp::Sequence& seq,
|
|||||||
const std::shared_ptr<kp::Tensor>& out,
|
const std::shared_ptr<kp::Tensor>& out,
|
||||||
uint32_t inOff, uint32_t outOff,
|
uint32_t inOff, uint32_t outOff,
|
||||||
int32_t ne00, int32_t nb01,
|
int32_t ne00, int32_t nb01,
|
||||||
int32_t nrows) {
|
int32_t nrows, float epsilon) {
|
||||||
GGML_ASSERT(nb01%sizeof(float) == 0);
|
GGML_ASSERT(nb01%sizeof(float) == 0);
|
||||||
GGML_ASSERT(ne00%sizeof(float) == 0);
|
GGML_ASSERT(ne00%sizeof(float) == 0);
|
||||||
|
|
||||||
const float epsilon = 1e-6f; // this is what ggml.c uses for rms norm
|
|
||||||
|
|
||||||
struct PushConstants {
|
struct PushConstants {
|
||||||
uint32_t inOff, outOff;
|
uint32_t inOff, outOff;
|
||||||
uint32_t ne00, nb01;
|
uint32_t ne00, nb01;
|
||||||
@ -1559,11 +1557,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
{
|
{
|
||||||
ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0));
|
float eps;
|
||||||
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
{
|
{
|
||||||
ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0));
|
float eps;
|
||||||
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user