mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 20:43:07 +01:00
llama : switch to floating-point token positions
ggml-ci
This commit is contained in:
parent
15499eb942
commit
fc775366f1
@ -1015,9 +1015,9 @@ static struct ggml_tensor * forward_lora(
|
||||
struct ggml_tensor * kc = kv_self.k;
|
||||
struct ggml_tensor * vc = kv_self.v;
|
||||
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, N);
|
||||
{
|
||||
int * data = (int *) KQ_pos->data;
|
||||
float * data = (float *) KQ_pos->data;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
data[i] = n_past + i;
|
||||
}
|
||||
|
@ -554,7 +554,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
|
||||
};
|
||||
|
||||
// KQ_pos - contains the positions
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N);
|
||||
ggml_set_input(KQ_pos);
|
||||
|
||||
// rope has so much parameters that we make a custom function for it
|
||||
@ -743,7 +743,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
|
||||
|
||||
// set KQ_pos
|
||||
{
|
||||
int * data = (int *) KQ_pos->data;
|
||||
float * data = (float *) KQ_pos->data;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
data[i] = n_past + i;
|
||||
}
|
||||
|
@ -338,7 +338,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, (float) *n_past, 1, 0, };
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
|
@ -1281,7 +1281,7 @@ struct llama_server_context
|
||||
}
|
||||
|
||||
const int n_embd = llama_n_embd(model);
|
||||
llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
|
||||
llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, (float) slot.n_past, 1, 0, };
|
||||
if (llama_decode(ctx, batch_img))
|
||||
{
|
||||
LOG_TEE("%s : failed to eval image\n", __func__);
|
||||
|
@ -291,7 +291,7 @@ static struct ggml_tensor * llama_build_train_graphs(
|
||||
};
|
||||
|
||||
// KQ_pos - contains the positions
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N);
|
||||
ggml_set_input(KQ_pos);
|
||||
|
||||
// rope has so much parameters that we make a custom function for it
|
||||
@ -419,7 +419,7 @@ static struct ggml_tensor * llama_build_train_graphs(
|
||||
ggml_gallocr_alloc_graph(alloc, gb);
|
||||
|
||||
if (!measure_only) {
|
||||
int * data = (int *) KQ_pos->data;
|
||||
float * data = (float *) KQ_pos->data;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
data[i] = n_past + i;
|
||||
}
|
||||
|
30
ggml-cuda.cu
30
ggml-cuda.cu
@ -6040,7 +6040,7 @@ static __device__ void rope_yarn(
|
||||
// rope == RoPE == rotary positional embedding
|
||||
template<typename T, bool has_pos>
|
||||
static __global__ void rope(
|
||||
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
||||
const T * x, T * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base,
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims
|
||||
) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
@ -6053,7 +6053,7 @@ static __global__ void rope(
|
||||
const int i = row*ncols + col;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const int p = has_pos ? pos[i2] : 0;
|
||||
const float p = has_pos ? pos[i2] : 0.0f;
|
||||
const float theta_base = p*powf(freq_base, -float(col)/ncols);
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
@ -6068,7 +6068,7 @@ static __global__ void rope(
|
||||
|
||||
template<typename T, bool has_pos>
|
||||
static __global__ void rope_neox(
|
||||
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
const T * x, T * dst, int ncols, int n_dims, const float * pos, float freq_scale, int p_delta_rows,
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
|
||||
) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
@ -6095,7 +6095,7 @@ static __global__ void rope_neox(
|
||||
|
||||
float cur_rot = inv_ndims * ic - ib;
|
||||
|
||||
const int p = has_pos ? pos[i2] : 0;
|
||||
const float p = has_pos ? pos[i2] : 0.0f;
|
||||
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
@ -6109,7 +6109,7 @@ static __global__ void rope_neox(
|
||||
}
|
||||
|
||||
static __global__ void rope_glm_f32(
|
||||
const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
||||
const float * x, float * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base,
|
||||
int n_ctx
|
||||
) {
|
||||
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
@ -6124,10 +6124,10 @@ static __global__ void rope_glm_f32(
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
|
||||
// FIXME: this is likely wrong
|
||||
const int p = pos != nullptr ? pos[i2] : 0;
|
||||
|
||||
const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
|
||||
const float p = pos != nullptr ? pos[i2] : 0.0f;
|
||||
|
||||
const float theta = min(p, (float) n_ctx - 2)*freq_scale*col_theta_scale;
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta);
|
||||
|
||||
@ -6137,7 +6137,7 @@ static __global__ void rope_glm_f32(
|
||||
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
|
||||
const float block_theta = max(p - n_ctx - 2, 0.0f)*col_theta_scale;
|
||||
const float sin_block_theta = sinf(block_theta);
|
||||
const float cos_block_theta = cosf(block_theta);
|
||||
|
||||
@ -7688,7 +7688,7 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const
|
||||
|
||||
template<typename T>
|
||||
static void rope_cuda(
|
||||
const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
const T * x, T * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
@ -7708,7 +7708,7 @@ static void rope_cuda(
|
||||
|
||||
template<typename T>
|
||||
static void rope_neox_cuda(
|
||||
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
const T * x, T * dst, int ncols, int n_dims, int nrows, const float * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
@ -7733,7 +7733,7 @@ static void rope_neox_cuda(
|
||||
}
|
||||
|
||||
static void rope_glm_f32_cuda(
|
||||
const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
const float * x, float * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, int n_ctx, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 4 == 0);
|
||||
@ -9035,11 +9035,11 @@ static void ggml_cuda_op_rope(
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const int32_t * pos = nullptr;
|
||||
const float * pos = nullptr;
|
||||
if ((mode & 1) == 0) {
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->ne[0] == ne2);
|
||||
pos = (const int32_t *) src1_dd;
|
||||
pos = (const float *) src1_dd;
|
||||
}
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
|
@ -2057,7 +2057,13 @@ static bool ggml_metal_graph_compute(
|
||||
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
|
@ -1674,7 +1674,7 @@ static void rope_yarn_corr_dims(
|
||||
|
||||
typedef void (rope_t)(
|
||||
device const void * src0,
|
||||
device const int32_t * src1,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
@ -1709,7 +1709,7 @@ typedef void (rope_t)(
|
||||
template<typename T>
|
||||
kernel void kernel_rope(
|
||||
device const void * src0,
|
||||
device const int32_t * src1,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
@ -1749,11 +1749,11 @@ kernel void kernel_rope(
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = src1;
|
||||
device const float * pos = src1;
|
||||
|
||||
const int64_t p = pos[i2];
|
||||
const float p = pos[i2];
|
||||
|
||||
const float theta_0 = (float)p;
|
||||
const float theta_0 = p;
|
||||
const float inv_ndims = -1.f/n_dims;
|
||||
|
||||
if (!is_neox) {
|
||||
|
12
ggml.c
12
ggml.c
@ -5254,7 +5254,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||
bool xpos_down,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
|
||||
bool is_node = false;
|
||||
@ -5377,7 +5377,7 @@ struct ggml_tensor * ggml_rope_back(
|
||||
float xpos_base,
|
||||
bool xpos_down) {
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
|
||||
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
|
||||
@ -12352,11 +12352,11 @@ static void ggml_compute_forward_rope_f32(
|
||||
// this essentially just switches the sign of sin.
|
||||
const float sin_sign = forward ? 1.0f : -1.0f;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
const float * pos = (const float *) src1->data;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
const int64_t p = pos[i2];
|
||||
const float p = pos[i2];
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
|
||||
@ -12523,11 +12523,11 @@ static void ggml_compute_forward_rope_f16(
|
||||
// this essentially just switches the sign of sin.
|
||||
const float sin_sign = forward ? 1.0f : -1.0f;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
const float * pos = (const float *) src1->data;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
const int64_t p = pos[i2];
|
||||
const float p = pos[i2];
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
|
||||
|
20
llama.cpp
20
llama.cpp
@ -1699,8 +1699,8 @@ struct llama_layer {
|
||||
};
|
||||
|
||||
struct llama_kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
float pos = -1.0f;
|
||||
float delta = 0.0f;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
@ -1939,10 +1939,10 @@ struct llama_context {
|
||||
ggml_context * ctx_input = nullptr;
|
||||
struct ggml_tensor * inp_tokens; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
|
||||
struct ggml_tensor * inp_pos; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_pos; // F32 [n_batch]
|
||||
struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
|
||||
struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx]
|
||||
struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
|
||||
struct ggml_tensor * inp_K_shift; // F32 [n_ctx]
|
||||
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
|
||||
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
||||
|
||||
@ -2222,7 +2222,7 @@ static void llama_kv_cache_seq_div(
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
float d) {
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
@ -7744,7 +7744,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
|
||||
assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
|
||||
|
||||
int32_t * data = (int32_t *) lctx.inp_K_shift->data;
|
||||
float * data = (float *) lctx.inp_K_shift->data;
|
||||
|
||||
for (int i = 0; i < n_ctx; ++i) {
|
||||
data[i] = lctx.kv_self.cells[i].delta;
|
||||
@ -11690,10 +11690,10 @@ struct llama_context * llama_new_context_with_model(
|
||||
|
||||
ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
|
||||
ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
|
||||
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
|
||||
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch);
|
||||
ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
|
||||
ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
|
||||
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
|
||||
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
|
||||
ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
|
||||
ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
|
||||
|
||||
@ -12046,7 +12046,7 @@ void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, l
|
||||
llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, float d) {
|
||||
if (d == 1) {
|
||||
return;
|
||||
}
|
||||
@ -12461,7 +12461,7 @@ int llama_eval_embd(
|
||||
int32_t n_past) {
|
||||
llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
|
||||
|
||||
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, (float) n_past, 1, 0, };
|
||||
|
||||
const int ret = llama_decode_internal(*ctx, batch);
|
||||
if (ret < 0) {
|
||||
|
4
llama.h
4
llama.h
@ -54,7 +54,7 @@ extern "C" {
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
|
||||
typedef int32_t llama_pos;
|
||||
typedef float llama_pos;
|
||||
typedef int32_t llama_token;
|
||||
typedef int32_t llama_seq_id;
|
||||
|
||||
@ -531,7 +531,7 @@ extern "C" {
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d);
|
||||
float d);
|
||||
|
||||
//
|
||||
// State / sessions
|
||||
|
@ -1134,14 +1134,15 @@ struct test_rope : public test_case {
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
|
||||
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[2]);
|
||||
ggml_set_name(pos, "pos");
|
||||
ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx);
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
if (strcmp(ggml_get_name(t), "pos") == 0) {
|
||||
// pos
|
||||
std::vector<int> data(ne[2]);
|
||||
for (int i = 0; i < ne[2]; i++) {
|
||||
@ -1703,7 +1704,7 @@ struct test_llama : public test_llm {
|
||||
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
|
||||
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_tokens);
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
|
||||
@ -1825,7 +1826,7 @@ struct test_falcon : public test_llm {
|
||||
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
|
||||
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_tokens);
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
|
||||
|
@ -1449,9 +1449,9 @@ int main(int argc, const char ** argv) {
|
||||
for (int n_past = 1; n_past < ne2[2]; ++n_past) {
|
||||
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||
|
||||
struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
|
||||
struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne2[2]);
|
||||
for (int i = 0; i < ne2[2]; ++i) {
|
||||
((int32_t *) p->data)[i] = n_past + i;
|
||||
((float *) p->data)[i] = n_past + i;
|
||||
}
|
||||
|
||||
ggml_set_param(ctx0, x[0]);
|
||||
@ -1489,9 +1489,9 @@ int main(int argc, const char ** argv) {
|
||||
for (int n_past = 1; n_past < ne2[2]; ++n_past) {
|
||||
x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||
|
||||
struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
|
||||
struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne2[2]);
|
||||
for (int i = 0; i < ne2[2]; ++i) {
|
||||
((int32_t *) p->data)[i] = n_past + i;
|
||||
((float *) p->data)[i] = n_past + i;
|
||||
}
|
||||
|
||||
ggml_set_param(ctx0, x[0]);
|
||||
|
@ -146,14 +146,14 @@ int main(int /*argc*/, const char ** /*argv*/) {
|
||||
const int n_past_0 = 100;
|
||||
const int n_past_2 = 33;
|
||||
|
||||
struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
|
||||
struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
|
||||
struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
|
||||
struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne[2]);
|
||||
struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne[2]);
|
||||
struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne[2]);
|
||||
|
||||
for (int i = 0; i < ne[2]; ++i) {
|
||||
((int32_t *) p0->data)[i] = n_past_0 + i;
|
||||
((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
|
||||
((int32_t *) p2->data)[i] = n_past_2 + i;
|
||||
((float *) p0->data)[i] = n_past_0 + i;
|
||||
((float *) p1->data)[i] = n_past_2 - n_past_0;
|
||||
((float *) p2->data)[i] = n_past_2 + i;
|
||||
}
|
||||
|
||||
// test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
|
||||
|
Loading…
Reference in New Issue
Block a user