mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-25 19:08:44 +01:00
68 lines
2.2 KiB
Plaintext
68 lines
2.2 KiB
Plaintext
|
#include "common.comp"
|
||
|
|
||
|
// Workgroup size: a single invocation per workgroup (local_size_y/z are not
// specified here, so they take their default of 1).
// TODO: use a local size of 32 or more (Metal uses 1024). The shaders that
// include this file would presumably have to split their per-workgroup work
// across invocations before this can change — confirm against those kernels.
|
||
|
layout(local_size_x = 1) in;
|
||
|
|
||
|
// Push-constant parameters for the RoPE kernels that include this file.
// The order and types of the members determine the block's byte layout, so
// they must stay in sync with the host-side struct that fills it
// (NOTE(review): host code not visible here — confirm before editing).
// Member meanings below are inferred from names and from the helper functions
// in this file; anything marked "presumably" is not established by this file.
layout (push_constant) uniform parameter {
|
||
|
uint inAOff; // presumably offset of input A within its bound buffer
|
||
|
uint inBOff; // presumably offset of input B within its bound buffer
|
||
|
uint outOff; // presumably offset of the output within its bound buffer
|
||
|
int n_dims; // same name as rope_yarn_corr_dims' n_dims; presumably forwarded to it
|
||
|
int mode; // RoPE variant selector; interpreted by the including kernel (not here)
|
||
|
int n_orig_ctx; // presumably forwarded to rope_yarn_corr_dims as n_orig_ctx
|
||
|
float freq_base; // presumably forwarded to rope_yarn_corr_dims as freq_base
|
||
|
float freq_scale; // presumably forwarded to rope_yarn as freq_scale
|
||
|
float ext_factor; // presumably forwarded to rope_yarn (0 disables its YaRN blend)
|
||
|
float attn_factor; // presumably the base mscale handed to rope_yarn
|
||
|
float beta_fast; // presumably forwarded to rope_yarn_corr_dims as beta_fast
|
||
|
float beta_slow; // presumably forwarded to rope_yarn_corr_dims as beta_slow
|
||
|
uint nb00; // presumably source stride for dim 0 (ggml nbXY naming) — confirm units
|
||
|
uint nb01; // presumably source stride for dim 1
|
||
|
uint nb02; // presumably source stride for dim 2
|
||
|
uint nb03; // presumably source stride for dim 3
|
||
|
int ne0; // presumably element count of dim 0
|
||
|
uint nb0; // presumably destination stride for dim 0 — confirm units
|
||
|
uint nb1; // presumably destination stride for dim 1
|
||
|
uint nb2; // presumably destination stride for dim 2
|
||
|
uint nb3; // presumably destination stride for dim 3
|
||
|
} pcs;
|
||
|
|
||
|
// Linear YaRN ramp evaluated at position i0 / 2.
// Returns 1.0 at or below `low`, 0.0 at or above `high`, and falls linearly
// in between. The span (high - low) is floored at 0.001 so that a degenerate
// or inverted range never divides by zero.
float rope_yarn_ramp(const float low, const float high, const float i0) {
    const float pos   = i0 / 2;
    const float width = max(0.001f, high - low);
    const float t     = max(0.0f, (pos - low) / width);
    return 1.0f - min(1.0f, t);
}
|
||
|
|
||
|
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Produces the magnitude-scaled cosine and sine of a rotary angle.
//
//   theta_extrap  unscaled angle
//   freq_scale    multiplier giving the interpolated angle (freq_scale * theta_extrap)
//   corr_dims     [low, high] bounds passed to rope_yarn_ramp
//   i0            element index; the ramp is evaluated at i0 / 2
//   ext_factor    weight on the ramp; 0.0 skips both the angle blend and the
//                 magnitude correction
//   mscale        base magnitude multiplier
//   cos_theta     out: cos(theta) * final magnitude
//   sin_theta     out: sin(theta) * final magnitude
//
// When ext_factor != 0 the angle becomes a ramp-weighted blend of the
// interpolated and extrapolated angles, and the magnitude is additionally
// multiplied by 1 + 0.1 * ln(1 / freq_scale).
void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
    out float cos_theta, out float sin_theta
) {
    const float theta_interp = freq_scale * theta_extrap;

    // Defaults: purely interpolated angle, caller's magnitude unchanged.
    float theta     = theta_interp;
    float magnitude = mscale;

    if (ext_factor != 0.0f) {
        // Blend weight in [0, ext_factor]: 1 near corr_dims[0], 0 past corr_dims[1].
        const float w = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        theta      = theta_interp * (1.0f - w) + theta_extrap * w;
        magnitude *= 1.0f + 0.1f * log(1.0f / freq_scale);
    }

    cos_theta = magnitude * cos(theta);
    sin_theta = magnitude * sin(theta);
}
|
||
|
|
||
|
// Fractional dimension index x at which a rotary component completes n_rot
// full turns over the original context. The returned value is the solution of
//     n_orig_ctx / (2*pi * base^(2*x / n_dims)) = n_rot
// which gives
//     x = n_dims * ln(n_orig_ctx / (n_rot * 2*pi)) / (2 * ln(base)).
// TWOPI_F comes from common.comp (presumably 2*pi). The result is not
// clamped. base must not be 1, because ln(base) would then be 0.
float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
    const float ratio       = float(n_orig_ctx) / (n_rot * TWOPI_F);
    const float numerator   = float(n_dims) * log(ratio);
    const float denominator = 2.0f * log(base);
    return numerator / denominator;
}
|
||
|
|
||
|
// Computes the [start, end] dimension bounds of the YaRN correction range,
// in a form usable as the corr_dims argument of rope_yarn:
//   dims[0] = floor(corr_factor(beta_fast)), clamped below at 0
//   dims[1] = ceil(corr_factor(beta_slow)),  clamped above at n_dims - 1
void rope_yarn_corr_dims(
    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
    const float lo = floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base));
    const float hi = ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base));

    dims[0] = max(0.0f, lo);
    dims[1] = min(float(n_dims) - 1.0f, hi);
}
|