mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 22:30:32 +01:00
ggml : move rope type enum to ggml.h (#8949)
* ggml : move rope type enum to ggml.h
This commit moves the `llama_rope_type` enum from `llama.h` to
`ggml.h` and changes its name to `ggml_rope_type`.
The motivation for this change is to address the TODO in `llama.h` and
use the enum in ggml.
Note: This commit does not change the `mode` parameter to be of type
`enum ggml_rope_type`. The name `mode` and its usage suggest that it
might be more generic and possibly used as a bit field for multiple
flags. Further investigation/discussion may be needed to determine
if `mode` should be restricted to RoPE types.
* squash! ggml : move rope type enum to ggml.h
This commit removes GGML_ROPE_TYPE_NONE and GGML_ROPE_TYPE_GLM from
ggml.h, and back the llama_rope_type enum.
I've kept the assert for GGML_ROPE_TYPE_GLM as I'm not sure if it is
safe to remove it yet.
* squash! ggml : move rope type enum to ggml.h
This commit removes the enum ggml_rope_type from ggml.h and replaces it
with a define (GGML_ROPE_TYPE_NEOX). This define is used in the code to
check if the mode is set to GPT-NeoX. Also the enum llama_rope_type has
been updated to reflect this change.
* squash! ggml : move rope type enum to ggml.h
This commit contains a suggestion enable the GGML_ROPE_TYPE_NEOX
macro/define to be passed to the shader compiler.
* squash! ggml : move rope type enum to ggml.h
This commit fixes the editorconfig-checker warnings.
* squash! ggml : move rope type enum to ggml.h
Update comment for ggml_rope function.
* Revert "squash! ggml : move rope type enum to ggml.h"
This reverts commit 6261222bd0
.
* squash! ggml : move rope type enum to ggml.h
Add GGML_ROPE_TYPE_NEOX to rope_common.comp.
* remove extra line
---------
Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
828d6ff7d7
commit
06943a69f6
@ -244,6 +244,8 @@
|
||||
#define GGML_EXIT_SUCCESS 0
|
||||
#define GGML_EXIT_ABORTED 1
|
||||
|
||||
#define GGML_ROPE_TYPE_NEOX 2
|
||||
|
||||
#define GGUF_MAGIC "GGUF"
|
||||
|
||||
#define GGUF_VERSION 3
|
||||
@ -1453,8 +1455,8 @@ extern "C" {
|
||||
struct ggml_tensor * b);
|
||||
|
||||
// rotary position embedding
|
||||
// if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
|
||||
// if mode & 2 == 1, GPT-NeoX style
|
||||
// if (mode & 1) - skip n_past elements (NOT SUPPORTED)
|
||||
// if (mode & GGML_ROPE_TYPE_NEOX) - GPT-NeoX style
|
||||
//
|
||||
// b is an int32 vector with size a->ne[2], it contains the positions
|
||||
GGML_API struct ggml_tensor * ggml_rope(
|
||||
|
@ -2881,7 +2881,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
|
||||
beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
// init cos/sin cache
|
||||
ggml_cann_pool_alloc sin_allocator(
|
||||
|
@ -226,7 +226,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1_d;
|
||||
|
||||
|
@ -2313,7 +2313,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
|
@ -226,7 +226,7 @@ void ggml_sycl_op_rope(
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1_dd;
|
||||
|
||||
|
@ -4053,7 +4053,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
const int mode = ((const int32_t *) dst->op_params)[2];
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
if (is_neox) {
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
|
@ -14094,7 +14094,7 @@ static void ggml_compute_forward_rope_f32(
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (src2 != NULL) {
|
||||
@ -14219,7 +14219,7 @@ static void ggml_compute_forward_rope_f16(
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (src2 != NULL) {
|
||||
|
@ -11,7 +11,7 @@ void main() {
|
||||
const uint i2 = gl_WorkGroupID.y;
|
||||
const uint i1 = gl_WorkGroupID.x;
|
||||
|
||||
const bool is_neox = (pcs.mode & 2) != 0;
|
||||
const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0;
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
|
||||
|
@ -11,7 +11,7 @@ void main() {
|
||||
const uint i2 = gl_WorkGroupID.y;
|
||||
const uint i1 = gl_WorkGroupID.x;
|
||||
|
||||
const bool is_neox = (pcs.mode & 2) != 0;
|
||||
const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0;
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
|
||||
|
@ -1,5 +1,7 @@
|
||||
#include "common.comp"
|
||||
|
||||
#define GGML_ROPE_TYPE_NEOX 2
|
||||
|
||||
// TODO: use a local size of 32 or more (Metal uses 1024)
|
||||
layout(local_size_x = 1) in;
|
||||
|
||||
|
@ -95,13 +95,10 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
|
||||
};
|
||||
|
||||
// note: these values should be synchronized with ggml_rope
|
||||
// TODO: maybe move this enum to ggml.h (ggml_rope_type)
|
||||
enum llama_rope_type {
|
||||
LLAMA_ROPE_TYPE_NONE = -1,
|
||||
LLAMA_ROPE_TYPE_NORM = 0,
|
||||
LLAMA_ROPE_TYPE_NEOX = 2,
|
||||
LLAMA_ROPE_TYPE_GLM = 4,
|
||||
LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
|
||||
};
|
||||
|
||||
enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
|
||||
|
Loading…
Reference in New Issue
Block a user