mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 21:10:24 +01:00
cuda : fix rope + add tests (#7452)
* cuda : fix rope pos data ggml-ci * ggml : drop mode & 1 == 1 support for ggml_rope ggml-ci * ggml : support freq_factors for f16 rope (CPU) ggml-ci * tests : add rope tests using frequency factors ggml-ci
This commit is contained in:
parent
201cc11afa
commit
3e5faa8503
@ -283,9 +283,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
if (is_neox) {
|
||||
pos = (const int32_t *) src1_d;
|
||||
pos = (const int32_t *) src1_d;
|
||||
|
||||
if (is_neox) {
|
||||
if (src2 != nullptr) {
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
|
20
ggml.c
20
ggml.c
@ -6245,6 +6245,8 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||
float xpos_base,
|
||||
bool xpos_down,
|
||||
bool inplace) {
|
||||
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
|
||||
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
@ -14413,7 +14415,7 @@ static void ggml_compute_forward_rope_f32(
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
|
||||
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
@ -14529,6 +14531,7 @@ static void ggml_compute_forward_rope_f32(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: deduplicate f16/f32 code
|
||||
static void ggml_compute_forward_rope_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst,
|
||||
@ -14536,6 +14539,7 @@ static void ggml_compute_forward_rope_f16(
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const struct ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
||||
return;
|
||||
@ -14588,6 +14592,17 @@ static void ggml_compute_forward_rope_f16(
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (is_neox) {
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
||||
// this essentially just switches the sign of sin.
|
||||
@ -14660,10 +14675,11 @@ static void ggml_compute_forward_rope_f16(
|
||||
|
||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||
float cur_rot = inv_ndims * ic - ib;
|
||||
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
||||
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
||||
&cos_theta, &sin_theta
|
||||
);
|
||||
sin_theta *= sin_sign;
|
||||
|
2
ggml.h
2
ggml.h
@ -1460,7 +1460,7 @@ extern "C" {
|
||||
struct ggml_tensor * b);
|
||||
|
||||
// rotary position embedding
|
||||
// if mode & 1 == 1, skip n_past elements (DEPRECATED)
|
||||
// if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
|
||||
// if mode & 2 == 1, GPT-NeoX style
|
||||
// if mode & 4 == 1, ChatGLM style
|
||||
//
|
||||
|
@ -1142,20 +1142,22 @@ struct test_rope : public test_case {
|
||||
int n_dims;
|
||||
int mode;
|
||||
int n_ctx;
|
||||
bool ff;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(type, ne, n_dims, mode, n_ctx);
|
||||
return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
|
||||
}
|
||||
|
||||
test_rope(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 10, 10, 1},
|
||||
int n_dims = 10, int mode = 0, int n_ctx = 512)
|
||||
: type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx) {}
|
||||
int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
|
||||
: type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
|
||||
ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx);
|
||||
ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
|
||||
ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -1169,7 +1171,12 @@ struct test_rope : public test_case {
|
||||
}
|
||||
ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int));
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
if (t->ne[0] == n_dims/2) {
|
||||
// frequency factors in the range [0.9f, 1.1f]
|
||||
init_tensor_uniform(t, 0.9f, 1.1f);
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2188,16 +2195,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
|
||||
|
||||
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
|
||||
test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512)); // llama 13B
|
||||
test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512)); // llama 30B
|
||||
test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512)); // llama 65B
|
||||
test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512)); // neox (falcon 7B)
|
||||
test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512)); // neox (falcon 7B)
|
||||
test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
|
||||
test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
|
||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm)
|
||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512)); // neox (phi-2)
|
||||
// TODO: ff not supported yet for !neox
|
||||
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, false)); // llama 7B
|
||||
test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, false)); // llama 13B
|
||||
test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, false)); // llama 30B
|
||||
test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, false)); // llama 65B
|
||||
|
||||
for (bool ff : {false, true}) { // freq_factors
|
||||
test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, ff)); // neox (falcon 7B)
|
||||
test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, ff)); // neox (falcon 7B)
|
||||
test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512, ff)); // neox (falcon 40B)
|
||||
test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, ff)); // neox (falcon 40B)
|
||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512, ff)); // neox (stablelm)
|
||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512, ff)); // neox (phi-2)
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
|
||||
|
Loading…
x
Reference in New Issue
Block a user