cuda : add RoPE kernel for mode == 2 (NeoX) (#2760)
* cuda : add RoPE kernel for mode == 2 (NeoX)

* falcon : do not offload the embeddings layer
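For context: the two RoPE modes differ only in which pair of elements each rotation touches. The default kernel rotates adjacent elements (x[i], x[i+1]), while the NeoX variant added here rotates elements split across the two halves of the row, (x[i], x[i + ncols/2]). A minimal CPU sketch of the two pairings, for illustration only (the helper names and the flat theta_base/theta_scale parameters are stand-ins, not part of the patch):

    // Illustrative reference only (not from the commit): how the two RoPE modes pair elements.
    #include <math.h>

    static void rope_ref_basic(const float * x, float * dst, int ncols, float theta_base, float theta_scale) {
        float theta = theta_base;
        for (int i = 0; i < ncols; i += 2) {          // adjacent pairs (x[i], x[i+1])
            const float c = cosf(theta);
            const float s = sinf(theta);
            dst[i + 0] = x[i + 0]*c - x[i + 1]*s;
            dst[i + 1] = x[i + 0]*s + x[i + 1]*c;
            theta *= theta_scale;                     // next pair gets a smaller rotation frequency
        }
    }

    static void rope_ref_neox(const float * x, float * dst, int ncols, float theta_base, float theta_scale) {
        float theta = theta_base;
        for (int i = 0; i < ncols/2; ++i) {           // split pairs (x[i], x[i + ncols/2])
            const float c = cosf(theta);
            const float s = sinf(theta);
            dst[i          ] = x[i]*c - x[i + ncols/2]*s;
            dst[i + ncols/2] = x[i]*s + x[i + ncols/2]*c;
            theta *= theta_scale;
        }
    }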
parent 87e3733f24
commit 3f460a2b72
ggml-cuda.cu (58 changed lines)
@@ -3907,28 +3907,27 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-// TODO: this implementation is wrong!
-//static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
-//                                     const float p_delta, const int p_delta_rows, const float theta_scale) {
-//    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
-//
-//    if (col >= ncols) {
-//        return;
-//    }
-//
-//    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-//    const int i = row*ncols + col/2;
-//
-//    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
-//    const float sin_theta = sinf(theta);
-//    const float cos_theta = cosf(theta);
-//
-//    const float x0 = x[i + 0];
-//    const float x1 = x[i + ncols/2];
-//
-//    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-//    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
-//}
+static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
+                                     const float p_delta, const int p_delta_rows, const float theta_scale) {
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i = row*ncols + col/2;
+
+    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+    const float sin_theta = sinf(theta);
+    const float cos_theta = cosf(theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + ncols/2];
+
+    dst[i + 0] = x0*cos_theta - x1*sin_theta;
+    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+}
 
 static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
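Compared to rope_f32 just above it, the new kernel keeps the same grid mapping (blockIdx.x selects the row, the y dimension walks column pairs) but changes the addressing: the element index is row*ncols + col/2 instead of row*ncols + col, and the rotation partner sits ncols/2 elements away instead of 1. For example, with ncols = 8 the thread assigned col = 4 rotates x[row*8 + 2] together with x[row*8 + 6].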
@@ -4799,13 +4798,21 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
-    GGML_ASSERT(nrows % 2 == 0);
+    GGML_ASSERT(nrows % 2 == 0); // GG: is this assert really needed? I don't see why
     const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
+static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
+                               const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
+    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
+    rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
+}
+
 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 4 == 0);
     const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
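The new launcher mirrors rope_f32_cuda: a (1, 2*CUDA_ROPE_BLOCK_SIZE, 1) block gives one thread per column pair, blockIdx.x covers the rows, and num_blocks_x is the ceiling of ncols / (2*CUDA_ROPE_BLOCK_SIZE). Assuming CUDA_ROPE_BLOCK_SIZE is 256 (its value is defined elsewhere in ggml-cuda.cu, not in this diff), a head size of ncols = 128 gives num_blocks_x = (128 + 511) / 512 = 1, i.e. one 512-thread block per row in which threads with col >= 128 return immediately. Note that the nrows % 2 assert is kept only in the f32 path, now with an in-code question about whether it is needed at all; the NeoX launcher omits it.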
@@ -5548,8 +5555,9 @@ inline void ggml_cuda_op_rope(
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
     } else if (is_neox) {
-        GGML_ASSERT(false && "RoPE NeoX not implemented yet");
-#pragma message("TODO: implement RoPE NeoX for CUDA")
+        GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
+        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
+        rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     } else {
         const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
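With this change the NeoX branch of ggml_cuda_op_rope follows the same pattern as the default branch: p0 starts at n_past * freq_scale when the (mode & 1) bit is clear and at 0 otherwise, freq_scale is passed as p_delta and ne01 as p_delta_rows, so the kernel advances the position by one freq_scale-scaled step for every ne01 rows. The new assert only rules out the ne00 != n_dims case (RoPE applied to a subset of the row), which remains unimplemented on CUDA.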
llama.cpp (20 changed lines)
@@ -1958,6 +1958,14 @@ static void llm_load_tensors(
             model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
             model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
             model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
+
+            if (backend_norm == GGML_BACKEND_GPU) {
+                vram_weights += ggml_nbytes(model.output_norm);
+                vram_weights += ggml_nbytes(model.output_norm_b);
+            }
+            if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+                vram_weights += ggml_nbytes(model.output);
+            }
         }
 
         const uint32_t n_ff = hparams.n_ff;
@@ -1978,6 +1986,11 @@ static void llm_load_tensors(
             if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) {
                 layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend);
                 layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend);
+
+                if (backend == GGML_BACKEND_GPU) {
+                    vram_weights += ggml_nbytes(layer.attn_norm_2);
+                    vram_weights += ggml_nbytes(layer.attn_norm_2_b);
+                }
             }
 
             layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
@@ -1985,6 +1998,13 @@ static void llm_load_tensors(
 
             layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
             layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
+
+            if (backend == GGML_BACKEND_GPU) {
+                vram_weights +=
+                    ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
+                    ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.wo) +
+                    ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
+            }
         }
     } break;
 default:
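The llama.cpp side of the change is bookkeeping for the Falcon loader: whenever the output norm, output, attn_norm_2 and per-layer attention/FFN weights end up on a GPU backend, their ggml_nbytes sizes are now added to vram_weights, so the VRAM accounting matches the tensors that are actually offloaded.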