Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2025-01-11 21:10:24 +01:00
CUDA: fix RoPE asserts, block sizes (#2833)
parent dd0dc366da
commit 92b1bbd2ec
@@ -4908,8 +4908,8 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
-    GGML_ASSERT(nrows % 2 == 0); // GG: is this assert really needed? I don't see why
-    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
+    GGML_ASSERT(ncols % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
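Why the change reads this way: each thread rotates one adjacent pair of values within a row, so the evenness requirement belongs on ncols (the dimension being rotated), not on nrows. A block of CUDA_ROPE_BLOCK_SIZE threads therefore covers 2*CUDA_ROPE_BLOCK_SIZE columns, which is exactly the per-block coverage the num_blocks_x formula budgets; the old block_dims launched twice that many threads per block, and the surplus ones only exited at the bounds check. A minimal sketch of the pairwise indexing this relies on, modeled on the rope_f32 kernel of this era (illustrative, not the verbatim kernel):

static __global__ void rope_f32_sketch(const float * x, float * dst, const int ncols,
                                       const float p0, const float p_delta,
                                       const int p_delta_rows, const float theta_scale) {
    // one thread per column pair: thread y handles columns (2*y, 2*y + 1)
    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
    if (col >= ncols) {
        return; // ncols % 2 == 0 guarantees no dangling half-pair here
    }
    const int row = blockDim.x*blockIdx.x + threadIdx.x; // blockDim.x == 1, so row == blockIdx.x
    const int i   = row*ncols + col;

    const float theta     = (p0 + p_delta*(row/p_delta_rows))*powf(theta_scale, col/2);
    const float sin_theta = sinf(theta);
    const float cos_theta = cosf(theta);

    const float x0 = x[i + 0];
    const float x1 = x[i + 1];

    dst[i + 0] = x0*cos_theta - x1*sin_theta; // 2D rotation of the pair by theta
    dst[i + 1] = x0*sin_theta + x1*cos_theta;
}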
@@ -4917,7 +4917,8 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
 
 static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                                const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
-    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
+    GGML_ASSERT(ncols % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
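The rope_neox_f32_cuda hunk applies the same fix and additionally gains the assert it was missing: the NeoX-style rotation also consumes columns in pairs (pairing column i with column i + ncols/2, per the NeoX interleaving; stated here as background, not shown in this hunk), so an odd ncols would be just as invalid there, yet the old code never checked it. The launch arithmetic shared by both wrappers is a plain ceiling division. A small host-side check of that arithmetic, assuming CUDA_ROPE_BLOCK_SIZE is 256 as defined in ggml-cuda.cu (the helper name is hypothetical):

#include <cassert>

#define CUDA_ROPE_BLOCK_SIZE 256 // assumed value of the define in ggml-cuda.cu

// Hypothetical helper: verifies that the fixed launch shape covers all columns.
static void check_rope_launch_math(const int ncols, const int nrows) {
    assert(ncols % 2 == 0); // mirrors the new GGML_ASSERT
    // one block = CUDA_ROPE_BLOCK_SIZE threads, one column pair per thread,
    // so each block covers 2*CUDA_ROPE_BLOCK_SIZE columns
    const int cols_per_block = 2*CUDA_ROPE_BLOCK_SIZE;
    const int num_blocks_x   = (ncols + cols_per_block - 1) / cols_per_block; // ceil(ncols/512)
    assert(num_blocks_x*cols_per_block >= ncols);       // full coverage
    assert((num_blocks_x - 1)*cols_per_block < ncols);  // no wholly idle block
    (void)nrows; // grid x-dimension is just nrows: one block row per tensor row
    // e.g. ncols = 4096 -> num_blocks_x = 8 (8 blocks * 256 threads * 2 cols = 4096 columns)
}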