mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 13:58:46 +01:00
parent
0eb332a10f
commit
2fffa0d61f
@@ -4539,7 +4539,7 @@ static __global__ void rope(
|
|||||||
const int i2 = row/p_delta_rows;
|
const int i2 = row/p_delta_rows;
|
||||||
|
|
||||||
const int p = has_pos ? pos[i2] : 0;
|
const int p = has_pos ? pos[i2] : 0;
|
||||||
const float theta_base = p*powf(freq_base, -col/ncols);
|
const float theta_base = p*powf(freq_base, -float(col)/ncols);
|
||||||
|
|
||||||
float cos_theta, sin_theta;
|
float cos_theta, sin_theta;
|
||||||
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||||
@@ -4566,8 +4566,8 @@ static __global__ void rope_neox(
|
|||||||
const int i = row*ncols + col/2;
|
const int i = row*ncols + col/2;
|
||||||
const int i2 = row/p_delta_rows;
|
const int i2 = row/p_delta_rows;
|
||||||
|
|
||||||
// simplified from `(row * ncols + col) * (-1 / ncols)`
|
// simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
|
||||||
const float cur_rot = -col/ncols - row;
|
const float cur_rot = -float(col)/ncols;
|
||||||
|
|
||||||
const int p = has_pos ? pos[i2] : 0;
|
const int p = has_pos ? pos[i2] : 0;
|
||||||
const float theta_base = p*powf(freq_base, cur_rot);
|
const float theta_base = p*powf(freq_base, cur_rot);
|
||||||
|
Loading…
Reference in New Issue
Block a user