mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-23 21:17:54 +01:00
llama : implement YaRN RoPE scaling (#2268)
Co-authored-by: cebtenzzre <cebtenzzre@gmail.com> Co-authored-by: Jeffrey Quesnelle <jquesnelle@gmail.com>
This commit is contained in:
parent
c43c2da8af
commit
898aeca90a
@ -219,12 +219,52 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
break;
|
||||
}
|
||||
params.rope_freq_scale = std::stof(argv[i]);
|
||||
} else if (arg == "--rope-scaling") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
std::string value(argv[i]);
|
||||
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
|
||||
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
|
||||
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
|
||||
else { invalid_param = true; break; }
|
||||
} else if (arg == "--rope-scale") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rope_freq_scale = 1.0f/std::stof(argv[i]);
|
||||
} else if (arg == "--yarn-orig-ctx") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_orig_ctx = std::stoi(argv[i]);
|
||||
} else if (arg == "--yarn-ext-factor") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_ext_factor = std::stof(argv[i]);
|
||||
} else if (arg == "--yarn-attn-factor") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_attn_factor = std::stof(argv[i]);
|
||||
} else if (arg == "--yarn-beta-fast") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_beta_fast = std::stof(argv[i]);
|
||||
} else if (arg == "--yarn-beta-slow") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_beta_slow = std::stof(argv[i]);
|
||||
} else if (arg == "--memory-f32") {
|
||||
params.memory_f16 = false;
|
||||
} else if (arg == "--top-p") {
|
||||
@ -716,9 +756,16 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||
printf(" --cfg-negative-prompt-file FNAME\n");
|
||||
printf(" negative prompt file to use for guidance. (default: empty)\n");
|
||||
printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale);
|
||||
printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
|
||||
printf(" --rope-scaling {none,linear,yarn}\n");
|
||||
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
|
||||
printf(" --rope-scale N RoPE context scaling factor, expands context by a factor of N\n");
|
||||
printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
|
||||
printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n");
|
||||
printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
|
||||
printf(" --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)\n");
|
||||
printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
|
||||
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
|
||||
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
|
||||
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
|
||||
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
|
||||
printf(" --no-penalize-nl do not penalize newline token\n");
|
||||
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
|
||||
@ -826,17 +873,23 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
|
||||
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
|
||||
auto cparams = llama_context_default_params();
|
||||
|
||||
cparams.n_ctx = params.n_ctx;
|
||||
cparams.n_batch = params.n_batch;
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
||||
cparams.mul_mat_q = params.mul_mat_q;
|
||||
cparams.seed = params.seed;
|
||||
cparams.f16_kv = params.memory_f16;
|
||||
cparams.logits_all = params.logits_all;
|
||||
cparams.embedding = params.embedding;
|
||||
cparams.rope_freq_base = params.rope_freq_base;
|
||||
cparams.rope_freq_scale = params.rope_freq_scale;
|
||||
cparams.n_ctx = params.n_ctx;
|
||||
cparams.n_batch = params.n_batch;
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
||||
cparams.mul_mat_q = params.mul_mat_q;
|
||||
cparams.seed = params.seed;
|
||||
cparams.f16_kv = params.memory_f16;
|
||||
cparams.logits_all = params.logits_all;
|
||||
cparams.embedding = params.embedding;
|
||||
cparams.rope_scaling_type = params.rope_scaling_type;
|
||||
cparams.rope_freq_base = params.rope_freq_base;
|
||||
cparams.rope_freq_scale = params.rope_freq_scale;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
|
||||
|
||||
return cparams;
|
||||
}
|
||||
|
@ -9,6 +9,7 @@
|
||||
#define LOG_NO_FILE_LINE_FUNCTION
|
||||
#include "log.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
@ -54,6 +55,12 @@ struct gpt_params {
|
||||
int32_t n_beams = 0; // if non-zero then use beam search of given width.
|
||||
float rope_freq_base = 0.0f; // RoPE base frequency
|
||||
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
|
||||
float yarn_ext_factor = NAN; // YaRN extrapolation mix factor
|
||||
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
|
||||
float yarn_beta_fast = 32.0f;// YaRN low correction dim
|
||||
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||
int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
|
||||
|
||||
// // sampling parameters
|
||||
struct llama_sampling_params sparams;
|
||||
|
@ -163,7 +163,8 @@ gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
||||
if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]:
|
||||
if "type" in hparams["rope_scaling"]:
|
||||
if hparams["rope_scaling"]["type"] == "linear":
|
||||
gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"])
|
||||
gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
|
||||
|
||||
|
||||
# TOKENIZATION
|
||||
|
97
convert.py
97
convert.py
@ -151,8 +151,11 @@ class Params:
|
||||
n_head_kv: int
|
||||
f_norm_eps: float
|
||||
|
||||
rope_scaling_type: gguf.RopeScalingType | None = None
|
||||
f_rope_freq_base: float | None = None
|
||||
f_rope_scale: float | None = None
|
||||
n_orig_ctx: int | None = None
|
||||
rope_finetuned: bool | None = None
|
||||
|
||||
ftype: GGMLFileType | None = None
|
||||
|
||||
@ -198,20 +201,20 @@ class Params:
|
||||
def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
|
||||
config = json.load(open(config_path))
|
||||
|
||||
n_vocab = config["vocab_size"]
|
||||
n_embd = config["hidden_size"]
|
||||
n_layer = config["num_hidden_layers"]
|
||||
n_ff = config["intermediate_size"]
|
||||
n_head = config["num_attention_heads"]
|
||||
n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
|
||||
f_norm_eps = config["rms_norm_eps"]
|
||||
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
|
||||
|
||||
rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
|
||||
rope_scaling = config.get("rope_scaling")
|
||||
if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear":
|
||||
f_rope_scale = config["rope_scaling"].get("factor")
|
||||
else:
|
||||
f_rope_scale = None
|
||||
|
||||
if rope_scaling is not None and (typ := rope_scaling.get("type")):
|
||||
rope_factor = rope_scaling.get("factor")
|
||||
f_rope_scale = rope_factor
|
||||
if typ == "linear":
|
||||
rope_scaling_type = gguf.RopeScalingType.LINEAR
|
||||
elif typ == "yarn":
|
||||
rope_scaling_type = gguf.RopeScalingType.YARN
|
||||
n_orig_ctx = rope_scaling['original_max_position_embeddings']
|
||||
rope_finetuned = rope_scaling['finetuned']
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown rope scaling type: {typ}')
|
||||
|
||||
if "max_sequence_length" in config:
|
||||
n_ctx = config["max_sequence_length"]
|
||||
@ -222,16 +225,19 @@ class Params:
|
||||
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")
|
||||
|
||||
return Params(
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_layer = n_layer,
|
||||
n_ctx = n_ctx,
|
||||
n_ff = n_ff,
|
||||
n_head = n_head,
|
||||
n_head_kv = n_head_kv,
|
||||
f_norm_eps = f_norm_eps,
|
||||
f_rope_freq_base = f_rope_freq_base,
|
||||
f_rope_scale = f_rope_scale,
|
||||
n_vocab = config["vocab_size"],
|
||||
n_embd = config["hidden_size"],
|
||||
n_layer = config["num_hidden_layers"],
|
||||
n_ctx = n_ctx,
|
||||
n_ff = config["intermediate_size"],
|
||||
n_head = (n_head := config["num_attention_heads"]),
|
||||
n_head_kv = config.get("num_key_value_heads", n_head),
|
||||
f_norm_eps = config["rms_norm_eps"],
|
||||
f_rope_freq_base = config.get("rope_theta"),
|
||||
rope_scaling_type = rope_scaling_type,
|
||||
f_rope_scale = f_rope_scale,
|
||||
n_orig_ctx = n_orig_ctx,
|
||||
rope_finetuned = rope_finetuned,
|
||||
)
|
||||
|
||||
# LLaMA v2 70B params.json
|
||||
@ -240,17 +246,8 @@ class Params:
|
||||
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
|
||||
config = json.load(open(config_path))
|
||||
|
||||
n_vocab = config["vocab_size"] if "vocab_size" in config else -1
|
||||
n_embd = config["dim"]
|
||||
n_layer = config["n_layers"]
|
||||
n_ff = -1
|
||||
n_head = config["n_heads"]
|
||||
n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head
|
||||
f_norm_eps = config["norm_eps"]
|
||||
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
|
||||
|
||||
# hack to determine LLaMA v1 vs v2 vs CodeLlama
|
||||
if f_rope_freq_base == 1000000:
|
||||
if config.get("rope_theta") == 1000000:
|
||||
# CodeLlama
|
||||
n_ctx = 16384
|
||||
elif config["norm_eps"] == 1e-05:
|
||||
@ -260,22 +257,16 @@ class Params:
|
||||
# LLaMA v1
|
||||
n_ctx = 2048
|
||||
|
||||
if n_vocab == -1:
|
||||
n_vocab = model["tok_embeddings.weight"].shape[0]
|
||||
|
||||
if n_ff == -1:
|
||||
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
|
||||
|
||||
return Params(
|
||||
n_vocab = n_vocab,
|
||||
n_embd = n_embd,
|
||||
n_layer = n_layer,
|
||||
n_vocab = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]),
|
||||
n_embd = config["dim"],
|
||||
n_layer = config["n_layers"],
|
||||
n_ctx = n_ctx,
|
||||
n_ff = n_ff,
|
||||
n_head = n_head,
|
||||
n_head_kv = n_head_kv,
|
||||
f_norm_eps = f_norm_eps,
|
||||
f_rope_freq_base = f_rope_freq_base,
|
||||
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0],
|
||||
n_head = (n_head := config["n_heads"]),
|
||||
n_head_kv = config.get("n_kv_heads", n_head),
|
||||
f_norm_eps = config["norm_eps"],
|
||||
f_rope_freq_base = config.get("rope_theta"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -831,8 +822,16 @@ class OutputFile:
|
||||
if params.f_rope_freq_base is not None:
|
||||
self.gguf.add_rope_freq_base(params.f_rope_freq_base)
|
||||
|
||||
if params.f_rope_scale is not None:
|
||||
self.gguf.add_rope_scale_linear(params.f_rope_scale)
|
||||
if params.rope_scaling_type:
|
||||
assert params.f_rope_scale is not None
|
||||
self.gguf.add_rope_scaling_type(params.rope_scaling_type)
|
||||
self.gguf.add_rope_scaling_factor(params.f_rope_scale)
|
||||
|
||||
if params.n_orig_ctx is not None:
|
||||
self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx)
|
||||
|
||||
if params.rope_finetuned is not None:
|
||||
self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)
|
||||
|
||||
if params.ftype is not None:
|
||||
self.gguf.add_file_type(params.ftype)
|
||||
|
@ -642,8 +642,9 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
|
||||
const int rope_mode = 0;
|
||||
|
||||
return ggml_rope_custom(ctx,
|
||||
t, KQ_pos, n_rot, rope_mode, n_ctx,
|
||||
rope_freq_base, rope_freq_scale);
|
||||
t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
|
||||
rope_freq_base, rope_freq_scale, 0.0f, 0.0f, 0.0f, 0.0f
|
||||
);
|
||||
};
|
||||
|
||||
set_name(tokens_input, "tokens_input");
|
||||
|
@ -1755,12 +1755,18 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
||||
printf("options:\n");
|
||||
printf(" -h, --help show this help message and exit\n");
|
||||
printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
|
||||
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||
printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n");
|
||||
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||
printf(" --rope-scaling {none,linear,yarn}\n");
|
||||
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
|
||||
printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
|
||||
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n");
|
||||
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
|
||||
printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
|
||||
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
|
||||
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
|
||||
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
|
||||
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
|
||||
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
|
||||
if (llama_mlock_supported())
|
||||
@ -1881,6 +1887,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||
}
|
||||
params.n_ctx = std::stoi(argv[i]);
|
||||
}
|
||||
else if (arg == "--rope-scaling")
|
||||
{
|
||||
if (++i >= argc)
|
||||
{
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
std::string value(argv[i]);
|
||||
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
|
||||
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
|
||||
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
|
||||
else { invalid_param = true; break; }
|
||||
}
|
||||
else if (arg == "--rope-freq-base")
|
||||
{
|
||||
if (++i >= argc)
|
||||
@ -1899,6 +1918,38 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||
}
|
||||
params.rope_freq_scale = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--yarn-ext-factor")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_ext_factor = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--yarn-attn-factor")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_attn_factor = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--yarn-beta-fast")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_beta_fast = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--yarn-beta-slow")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.yarn_beta_slow = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--memory-f32" || arg == "--memory_f32")
|
||||
{
|
||||
params.memory_f16 = false;
|
||||
|
@ -349,9 +349,9 @@ static struct ggml_tensor * llama_build_train_graphs(
|
||||
// not capturing these, to silcence warnings
|
||||
const int rope_mode = 0;
|
||||
|
||||
return ggml_rope_custom(ctx,
|
||||
t, KQ_pos, n_rot, rope_mode, n_ctx,
|
||||
rope_freq_base, rope_freq_scale);
|
||||
return ggml_rope_custom(
|
||||
ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
|
||||
);
|
||||
};
|
||||
|
||||
set_name(tokens_input, "tokens_input");
|
||||
|
153
ggml-cuda.cu
153
ggml-cuda.cu
@ -4493,11 +4493,41 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
||||
cpy_1(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
// rope == RoPE == rotary positional embedding
|
||||
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
}
|
||||
|
||||
struct rope_corr_dims {
|
||||
float v[4];
|
||||
};
|
||||
|
||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static __device__ void rope_yarn(
|
||||
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
|
||||
float * cos_theta, float * sin_theta
|
||||
) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||
}
|
||||
*cos_theta = cosf(theta) * mscale;
|
||||
*sin_theta = sinf(theta) * mscale;
|
||||
}
|
||||
|
||||
// rope == RoPE == rotary positional embedding
|
||||
template<typename T, bool has_pos>
|
||||
static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale) {
|
||||
static __global__ void rope(
|
||||
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims
|
||||
) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (col >= ncols) {
|
||||
@ -4509,10 +4539,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const int p = has_pos ? pos[i2] : 0;
|
||||
const float p0 = p*freq_scale;
|
||||
const float theta = p0*powf(theta_scale, col/2);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta);
|
||||
const float theta_base = p*powf(freq_base, -col/ncols);
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
const float x0 = x[i + 0];
|
||||
const float x1 = x[i + 1];
|
||||
@ -4522,8 +4552,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
|
||||
}
|
||||
|
||||
template<typename T, bool has_pos>
|
||||
static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale) {
|
||||
static __global__ void rope_neox(
|
||||
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims
|
||||
) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (col >= ncols) {
|
||||
@ -4534,11 +4566,14 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
|
||||
const int i = row*ncols + col/2;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
// simplified from `(row * ncols + col) * (-1 / ncols)`
|
||||
const float cur_rot = -col/ncols - row;
|
||||
|
||||
const int p = has_pos ? pos[i2] : 0;
|
||||
const float p0 = p*freq_scale;
|
||||
const float theta = p0*powf(theta_scale, col/2);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta);
|
||||
const float theta_base = p*powf(freq_base, cur_rot);
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
const float x0 = x[i + 0];
|
||||
const float x1 = x[i + ncols/2];
|
||||
@ -4547,8 +4582,10 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
|
||||
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale, const int n_ctx) {
|
||||
static __global__ void rope_glm_f32(
|
||||
const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
||||
int n_ctx
|
||||
) {
|
||||
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const int half_n_dims = ncols/4;
|
||||
|
||||
@ -4560,7 +4597,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
|
||||
const int i = row*ncols + col;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const float col_theta_scale = powf(theta_scale, col);
|
||||
const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
|
||||
// FIXME: this is likely wrong
|
||||
const int p = pos != nullptr ? pos[i2] : 0;
|
||||
|
||||
@ -5584,40 +5621,54 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
|
||||
static void rope_cuda(
|
||||
const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nrows, num_blocks_x, 1);
|
||||
if (pos == nullptr) {
|
||||
rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
|
||||
rope<T, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
|
||||
);
|
||||
} else {
|
||||
rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
|
||||
rope<T, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
|
||||
static void rope_neox_cuda(
|
||||
const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nrows, num_blocks_x, 1);
|
||||
if (pos == nullptr) {
|
||||
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
|
||||
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
|
||||
);
|
||||
} else {
|
||||
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
|
||||
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
|
||||
const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
|
||||
static void rope_glm_f32_cuda(
|
||||
const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, int n_ctx, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 4 == 0);
|
||||
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
|
||||
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
|
||||
const dim3 block_nums(num_blocks_x, nrows, 1);
|
||||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
|
||||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
|
||||
}
|
||||
|
||||
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
|
||||
@ -6477,17 +6528,20 @@ inline void ggml_cuda_op_rope(
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
// RoPE alteration for extended context
|
||||
|
||||
float freq_base, freq_scale;
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const int32_t * pos = nullptr;
|
||||
if ((mode & 1) == 0) {
|
||||
@ -6499,24 +6553,39 @@ inline void ggml_cuda_op_rope(
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
rope_corr_dims corr_dims;
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
|
||||
// compute
|
||||
if (is_glm) {
|
||||
GGML_ASSERT(false);
|
||||
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
|
||||
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
|
||||
} else if (is_neox) {
|
||||
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
|
||||
rope_neox_cuda(
|
||||
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, main_stream
|
||||
);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
|
||||
rope_neox_cuda(
|
||||
(const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, main_stream
|
||||
);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
} else {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
|
||||
rope_cuda(
|
||||
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, main_stream
|
||||
);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
|
||||
rope_cuda(
|
||||
(const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, main_stream
|
||||
);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
22
ggml-metal.m
22
ggml-metal.m
@ -1400,14 +1400,18 @@ void ggml_metal_graph_compute(
|
||||
|
||||
const int nth = MIN(1024, ne00);
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
|
||||
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
|
||||
@ -1439,6 +1443,10 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
||||
[encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
|
||||
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
|
||||
[encoder setBytes:&ext_factor length:sizeof(float) atIndex:24];
|
||||
[encoder setBytes:&attn_factor length:sizeof(float) atIndex:25];
|
||||
[encoder setBytes:&beta_fast length:sizeof(float) atIndex:26];
|
||||
[encoder setBytes:&beta_slow length:sizeof(float) atIndex:27];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
|
@ -1061,6 +1061,45 @@ kernel void kernel_alibi_f32(
|
||||
}
|
||||
}
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
}
|
||||
|
||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static void rope_yarn(
|
||||
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
||||
float * cos_theta, float * sin_theta
|
||||
) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||
}
|
||||
*cos_theta = cosf(theta) * mscale;
|
||||
*sin_theta = sinf(theta) * mscale;
|
||||
}
|
||||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
||||
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
||||
}
|
||||
|
||||
static void rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
) {
|
||||
// start and end correction dims
|
||||
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
||||
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
||||
}
|
||||
|
||||
typedef void (rope_t)(
|
||||
device const void * src0,
|
||||
device const int32_t * src1,
|
||||
@ -1116,6 +1155,10 @@ kernel void kernel_rope(
|
||||
constant int & mode,
|
||||
constant float & freq_base,
|
||||
constant float & freq_scale,
|
||||
constant float & ext_factor,
|
||||
constant float & attn_factor,
|
||||
constant float & beta_fast,
|
||||
constant float & beta_slow,
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg[[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
@ -1125,19 +1168,22 @@ kernel void kernel_rope(
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = src1;
|
||||
|
||||
const int64_t p = pos[i2];
|
||||
|
||||
const float theta_0 = freq_scale * (float)p;
|
||||
const float theta_0 = (float)p;
|
||||
const float inv_ndims = -1.f/n_dims;
|
||||
|
||||
if (!is_neox) {
|
||||
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||
|
||||
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
||||
const float cos_theta = cos(theta);
|
||||
const float sin_theta = sin(theta);
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
@ -1152,9 +1198,12 @@ kernel void kernel_rope(
|
||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
||||
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
||||
|
||||
const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
|
||||
const float cos_theta = cos(theta);
|
||||
const float sin_theta = sin(theta);
|
||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||
const float cur_rot = inv_ndims*ic - ib;
|
||||
|
||||
const float theta = theta_0 * pow(freq_base, cur_rot);
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
const int64_t i0 = ib*n_dims + ic/2;
|
||||
|
||||
|
241
ggml.c
241
ggml.c
@ -1,4 +1,5 @@
|
||||
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
|
||||
#define _USE_MATH_DEFINES // For M_PI on MSVC
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-quants.h"
|
||||
@ -4845,8 +4846,13 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
float xpos_base,
|
||||
bool xpos_down,
|
||||
bool inplace) {
|
||||
@ -4862,11 +4868,15 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
|
||||
memcpy(params + 4, &freq_base, sizeof(float));
|
||||
memcpy(params + 5, &freq_scale, sizeof(float));
|
||||
memcpy(params + 6, &xpos_base, sizeof(float));
|
||||
memcpy(params + 7, &xpos_down, sizeof(bool));
|
||||
int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
|
||||
memcpy(params + 5, &freq_base, sizeof(float));
|
||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||
memcpy(params + 8, &attn_factor, sizeof(float));
|
||||
memcpy(params + 9, &beta_fast, sizeof(float));
|
||||
memcpy(params + 10, &beta_slow, sizeof(float));
|
||||
memcpy(params + 11, &xpos_base, sizeof(float));
|
||||
memcpy(params + 12, &xpos_down, sizeof(bool));
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_ROPE;
|
||||
@ -4884,7 +4894,9 @@ struct ggml_tensor * ggml_rope(
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx) {
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
|
||||
);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_inplace(
|
||||
@ -4894,7 +4906,9 @@ struct ggml_tensor * ggml_rope_inplace(
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx) {
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
|
||||
);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_custom(
|
||||
@ -4904,9 +4918,17 @@ struct ggml_tensor * ggml_rope_custom(
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale) {
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
|
||||
);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_custom_inplace(
|
||||
@ -4916,9 +4938,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale) {
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
|
||||
);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||
@ -4928,7 +4958,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||
int n_dims,
|
||||
float base,
|
||||
bool down) {
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
|
||||
}
|
||||
|
||||
// ggml_rope_back
|
||||
@ -10901,6 +10931,45 @@ static void ggml_compute_forward_clamp(
|
||||
|
||||
// ggml_compute_forward_rope
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
|
||||
return 1 - MIN(1, MAX(0, y));
|
||||
}
|
||||
|
||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static void rope_yarn(
|
||||
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
||||
float * cos_theta, float * sin_theta
|
||||
) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||
}
|
||||
*cos_theta = cosf(theta) * mscale;
|
||||
*sin_theta = sinf(theta) * mscale;
|
||||
}
|
||||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||
static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
||||
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
||||
}
|
||||
|
||||
void ggml_rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
) {
|
||||
// start and end correction dims
|
||||
dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
||||
dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rope_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
@ -10910,21 +10979,26 @@ static void ggml_compute_forward_rope_f32(
|
||||
return;
|
||||
}
|
||||
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
|
||||
// these two only relevant for xPos RoPE:
|
||||
float xpos_base;
|
||||
bool xpos_down;
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
@ -10952,6 +11026,9 @@ static void ggml_compute_forward_rope_f32(
|
||||
int ir = 0;
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
const float inv_ndims = -1.f/n_dims;
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
@ -10965,18 +11042,18 @@ static void ggml_compute_forward_rope_f32(
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
float theta = freq_scale * (float)p;
|
||||
float theta_base = (float)p;
|
||||
|
||||
if (is_glm) {
|
||||
theta = MIN(p, n_ctx - 2);
|
||||
theta_base = MIN(p, n_ctx - 2);
|
||||
float block_theta = MAX(p - (n_ctx - 2), 0);
|
||||
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
||||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta_base);
|
||||
const float sin_theta = sinf(theta_base);
|
||||
const float cos_block_theta = cosf(block_theta);
|
||||
const float sin_block_theta = sinf(block_theta);
|
||||
|
||||
theta *= theta_scale;
|
||||
theta_base *= theta_scale;
|
||||
block_theta *= theta_scale;
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
@ -10994,13 +11071,16 @@ static void ggml_compute_forward_rope_f32(
|
||||
}
|
||||
} else if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
|
||||
);
|
||||
|
||||
// zeta scaling for xPos only:
|
||||
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
||||
if (xpos_down) zeta = 1.0f / zeta;
|
||||
|
||||
theta *= theta_scale;
|
||||
theta_base *= theta_scale;
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
@ -11014,12 +11094,19 @@ static void ggml_compute_forward_rope_f32(
|
||||
} else {
|
||||
// TODO: this might be wrong for ne0 != n_dims - need double check
|
||||
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
|
||||
theta_base *= freq_scale;
|
||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
||||
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
||||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||
float cur_rot = inv_ndims * ic - ib;
|
||||
|
||||
theta *= theta_scale;
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
||||
&cos_theta, &sin_theta
|
||||
);
|
||||
|
||||
theta_base *= theta_scale;
|
||||
|
||||
const int64_t i0 = ib*n_dims + ic/2;
|
||||
|
||||
@ -11048,15 +11135,19 @@ static void ggml_compute_forward_rope_f16(
|
||||
return;
|
||||
}
|
||||
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
@ -11084,6 +11175,9 @@ static void ggml_compute_forward_rope_f16(
|
||||
int ir = 0;
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
const float inv_ndims = -1.f/n_dims;
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
@ -11097,18 +11191,18 @@ static void ggml_compute_forward_rope_f16(
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
float theta = freq_scale * (float)p;
|
||||
float theta_base = (float)p;
|
||||
|
||||
if (is_glm) {
|
||||
theta = MIN(p, n_ctx - 2);
|
||||
theta_base = MIN(p, n_ctx - 2);
|
||||
float block_theta = MAX(p - (n_ctx - 2), 0);
|
||||
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
||||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta_base);
|
||||
const float sin_theta = sinf(theta_base);
|
||||
const float cos_block_theta = cosf(block_theta);
|
||||
const float sin_block_theta = sinf(block_theta);
|
||||
|
||||
theta *= theta_scale;
|
||||
theta_base *= theta_scale;
|
||||
block_theta *= theta_scale;
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
@ -11126,10 +11220,12 @@ static void ggml_compute_forward_rope_f16(
|
||||
}
|
||||
} else if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
|
||||
);
|
||||
|
||||
theta *= theta_scale;
|
||||
theta_base *= theta_scale;
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
@ -11143,12 +11239,19 @@ static void ggml_compute_forward_rope_f16(
|
||||
} else {
|
||||
// TODO: this might be wrong for ne0 != n_dims - need double check
|
||||
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
|
||||
theta_base *= freq_scale;
|
||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
||||
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
||||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||
float cur_rot = inv_ndims * ic - ib;
|
||||
|
||||
theta *= theta_scale;
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
||||
&cos_theta, &sin_theta
|
||||
);
|
||||
|
||||
theta_base *= theta_scale;
|
||||
|
||||
const int64_t i0 = ib*n_dims + ic/2;
|
||||
|
||||
@ -11256,17 +11359,18 @@ static void ggml_compute_forward_rope_back_f32(
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
float theta = freq_scale * (float)p;
|
||||
float theta_base = freq_scale * (float)p;
|
||||
|
||||
if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta_base);
|
||||
const float sin_theta = sinf(theta_base);
|
||||
|
||||
// zeta scaling for xPos only:
|
||||
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
||||
if (xpos_down) zeta = 1.0f / zeta;
|
||||
|
||||
theta *= theta_scale;
|
||||
theta_base *= theta_scale;
|
||||
|
||||
const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
@ -11280,10 +11384,10 @@ static void ggml_compute_forward_rope_back_f32(
|
||||
} else {
|
||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
||||
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
||||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta_base);
|
||||
const float sin_theta = sinf(theta_base);
|
||||
|
||||
theta *= theta_scale;
|
||||
theta_base *= theta_scale;
|
||||
|
||||
const int64_t i0 = ib*n_dims + ic/2;
|
||||
|
||||
@ -11356,14 +11460,14 @@ static void ggml_compute_forward_rope_back_f16(
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
float theta = (float)p;
|
||||
float theta_base = (float)p;
|
||||
|
||||
if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta_base);
|
||||
const float sin_theta = sinf(theta_base);
|
||||
|
||||
theta *= theta_scale;
|
||||
theta_base *= theta_scale;
|
||||
|
||||
const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
@ -11377,10 +11481,10 @@ static void ggml_compute_forward_rope_back_f16(
|
||||
} else {
|
||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
||||
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
||||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta_base);
|
||||
const float sin_theta = sinf(theta_base);
|
||||
|
||||
theta *= theta_scale;
|
||||
theta_base *= theta_scale;
|
||||
|
||||
const int64_t i0 = ib*n_dims + ic/2;
|
||||
|
||||
@ -15505,9 +15609,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
src1,
|
||||
n_dims,
|
||||
mode,
|
||||
0,
|
||||
n_ctx,
|
||||
freq_base,
|
||||
freq_scale,
|
||||
0.0f,
|
||||
1.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
xpos_base,
|
||||
xpos_down,
|
||||
false),
|
||||
|
20
ggml.h
20
ggml.h
@ -219,7 +219,7 @@
|
||||
#define GGML_MAX_CONTEXTS 64
|
||||
#define GGML_MAX_SRC 6
|
||||
#define GGML_MAX_NAME 64
|
||||
#define GGML_MAX_OP_PARAMS 32
|
||||
#define GGML_MAX_OP_PARAMS 64
|
||||
#define GGML_DEFAULT_N_THREADS 4
|
||||
|
||||
#if UINTPTR_MAX == 0xFFFFFFFF
|
||||
@ -1326,8 +1326,13 @@ extern "C" {
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale);
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
|
||||
// in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
|
||||
@ -1337,8 +1342,17 @@ extern "C" {
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale);
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
|
||||
// compute correction dims for YaRN RoPE scaling
|
||||
void ggml_rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
||||
|
||||
// xPos RoPE, in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||
|
@ -7,7 +7,7 @@ import shutil
|
||||
import struct
|
||||
import sys
|
||||
import tempfile
|
||||
from enum import IntEnum, auto
|
||||
from enum import Enum, IntEnum, auto
|
||||
from io import BufferedWriter
|
||||
from pathlib import Path
|
||||
from typing import IO, Any, BinaryIO, Callable, Sequence
|
||||
@ -53,9 +53,12 @@ KEY_ATTENTION_LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||
KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
|
||||
# RoPE
|
||||
KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base"
|
||||
KEY_ROPE_SCALE_LINEAR = "{arch}.rope.scale_linear"
|
||||
KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base"
|
||||
KEY_ROPE_SCALING_TYPE = "{arch}.rope.scaling.type"
|
||||
KEY_ROPE_SCALING_FACTOR = "{arch}.rope.scaling.factor"
|
||||
KEY_ROPE_SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
|
||||
KEY_ROPE_SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
|
||||
|
||||
# tokenization
|
||||
KEY_TOKENIZER_MODEL = "tokenizer.ggml.model"
|
||||
@ -577,6 +580,11 @@ class TokenType(IntEnum):
|
||||
UNUSED = 5
|
||||
BYTE = 6
|
||||
|
||||
class RopeScalingType(Enum):
|
||||
NONE = 'none'
|
||||
LINEAR = 'linear'
|
||||
YARN = 'yarn'
|
||||
|
||||
#
|
||||
# implementation
|
||||
#
|
||||
@ -948,8 +956,17 @@ class GGUFWriter:
|
||||
def add_rope_freq_base(self, value: float):
|
||||
self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scale_linear(self, value: float):
|
||||
self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value)
|
||||
def add_rope_scaling_type(self, value: RopeScalingType):
|
||||
self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value)
|
||||
|
||||
def add_rope_scaling_factor(self, value: float):
|
||||
self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_orig_ctx_len(self, value: int):
|
||||
self.add_uint32(KEY_ROPE_SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_finetuned(self, value: bool):
|
||||
self.add_bool(KEY_ROPE_SCALING_FINETUNED.format(arch=self.arch), value)
|
||||
|
||||
def add_tokenizer_model(self, model: str):
|
||||
self.add_string(KEY_TOKENIZER_MODEL, model)
|
||||
|
220
llama.cpp
220
llama.cpp
@ -54,6 +54,7 @@
|
||||
#include <cassert>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <cstdarg>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
@ -235,6 +236,10 @@ enum llm_kv {
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
LLM_KV_ROPE_SCALE_LINEAR,
|
||||
LLM_KV_ROPE_SCALING_TYPE,
|
||||
LLM_KV_ROPE_SCALING_FACTOR,
|
||||
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
|
||||
LLM_KV_ROPE_SCALING_FINETUNED,
|
||||
|
||||
LLM_KV_TOKENIZER_MODEL,
|
||||
LLM_KV_TOKENIZER_LIST,
|
||||
@ -276,9 +281,13 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
|
||||
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
|
||||
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
|
||||
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
|
||||
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
|
||||
|
||||
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
|
||||
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
|
||||
@ -552,6 +561,22 @@ do { \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
|
||||
{ LLAMA_ROPE_SCALING_NONE, "none" },
|
||||
{ LLAMA_ROPE_SCALING_LINEAR, "linear" },
|
||||
{ LLAMA_ROPE_SCALING_YARN, "yarn" },
|
||||
};
|
||||
|
||||
static int8_t llama_rope_scaling_type_from_string(const std::string & name) {
|
||||
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
|
||||
if (kv.second == name) {
|
||||
return kv.first;
|
||||
}
|
||||
}
|
||||
|
||||
return LLAMA_ROPE_SCALING_UNSPECIFIED;
|
||||
}
|
||||
|
||||
//
|
||||
// ggml helpers
|
||||
//
|
||||
@ -1035,8 +1060,11 @@ struct llama_hparams {
|
||||
float f_norm_eps;
|
||||
float f_norm_rms_eps;
|
||||
|
||||
float rope_freq_base_train;
|
||||
float rope_freq_scale_train;
|
||||
float rope_freq_base_train;
|
||||
float rope_freq_scale_train;
|
||||
uint32_t n_yarn_orig_ctx;
|
||||
int8_t rope_scaling_type_train : 3;
|
||||
bool rope_finetuned : 1;
|
||||
|
||||
float f_clamp_kqv;
|
||||
float f_max_alibi_bias;
|
||||
@ -1051,6 +1079,8 @@ struct llama_hparams {
|
||||
if (this->n_layer != other.n_layer) return true;
|
||||
if (this->n_rot != other.n_rot) return true;
|
||||
if (this->n_ff != other.n_ff) return true;
|
||||
if (this->rope_finetuned != other.rope_finetuned) return true;
|
||||
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
|
||||
|
||||
const float EPSILON = 1e-9;
|
||||
|
||||
@ -1081,8 +1111,16 @@ struct llama_cparams {
|
||||
uint32_t n_threads; // number of threads to use for generation
|
||||
uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
||||
uint32_t n_yarn_orig_ctx;
|
||||
// These hyperparameters are not exposed in GGUF, because all
|
||||
// existing YaRN models use the same values for them.
|
||||
float yarn_ext_factor;
|
||||
float yarn_attn_factor;
|
||||
float yarn_beta_fast;
|
||||
float yarn_beta_slow;
|
||||
|
||||
bool mul_mat_q;
|
||||
};
|
||||
@ -2014,14 +2052,30 @@ static void llm_load_hparams(
|
||||
hparams.n_head_kv = hparams.n_head;
|
||||
GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
|
||||
|
||||
hparams.rope_finetuned = false;
|
||||
GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false,
|
||||
kv(LLM_KV_ROPE_SCALING_FINETUNED));
|
||||
|
||||
hparams.n_yarn_orig_ctx = hparams.n_ctx_train;
|
||||
GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
|
||||
kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN));
|
||||
|
||||
// rope_freq_base (optional)
|
||||
hparams.rope_freq_base_train = 10000.0f;
|
||||
GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
|
||||
|
||||
std::string rope_scaling("linear");
|
||||
GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE));
|
||||
hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
|
||||
GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
|
||||
|
||||
// rope_freq_scale (inverse of the kv) is optional
|
||||
float ropescale = 1.0f;
|
||||
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
|
||||
hparams.rope_freq_scale_train = 1.0f/ropescale;
|
||||
float ropescale = 0.0f;
|
||||
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR));
|
||||
if (ropescale == 0.0f) { // try the old key name
|
||||
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
|
||||
}
|
||||
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
|
||||
|
||||
// sanity check for n_rot (optional)
|
||||
{
|
||||
@ -2371,6 +2425,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & vocab = model.vocab;
|
||||
|
||||
const auto rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
|
||||
|
||||
// hparams
|
||||
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
|
||||
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
|
||||
@ -2389,8 +2445,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
||||
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
|
||||
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
|
||||
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
|
||||
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
|
||||
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
|
||||
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
|
||||
LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
|
||||
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
|
||||
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
|
||||
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
|
||||
LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
|
||||
@ -3047,21 +3106,11 @@ static void llm_load_tensors(
|
||||
model.t_load_us = ggml_time_us() - model.t_start_us;
|
||||
}
|
||||
|
||||
static bool llama_model_load(
|
||||
const std::string & fname,
|
||||
llama_model & model,
|
||||
int n_gpu_layers,
|
||||
int main_gpu,
|
||||
const float * tensor_split,
|
||||
bool use_mmap,
|
||||
bool use_mlock,
|
||||
bool vocab_only,
|
||||
llama_progress_callback progress_callback,
|
||||
void *progress_callback_user_data) {
|
||||
static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
|
||||
try {
|
||||
llama_model_loader ml(fname, use_mmap);
|
||||
llama_model_loader ml(fname, params.use_mmap);
|
||||
|
||||
model.hparams.vocab_only = vocab_only;
|
||||
model.hparams.vocab_only = params.vocab_only;
|
||||
|
||||
llm_load_arch (ml, model);
|
||||
llm_load_hparams(ml, model);
|
||||
@ -3073,15 +3122,15 @@ static bool llama_model_load(
|
||||
throw std::runtime_error("vocab size mismatch");
|
||||
}
|
||||
|
||||
if (vocab_only) {
|
||||
if (params.vocab_only) {
|
||||
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
|
||||
return true;
|
||||
}
|
||||
|
||||
llm_load_tensors(
|
||||
ml, model, n_gpu_layers,
|
||||
main_gpu, tensor_split,
|
||||
use_mlock, progress_callback, progress_callback_user_data);
|
||||
ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock,
|
||||
params.progress_callback, params.progress_callback_user_data
|
||||
);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
|
||||
return false;
|
||||
@ -3150,6 +3199,7 @@ static struct ggml_tensor * llm_build_inp_embd(
|
||||
static void llm_build_k_shift(
|
||||
struct ggml_context * ctx,
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache & kv,
|
||||
struct ggml_cgraph * graph,
|
||||
llm_rope_type type,
|
||||
@ -3162,6 +3212,11 @@ static void llm_build_k_shift(
|
||||
const int64_t n_head_kv = hparams.n_head_kv;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
|
||||
const float ext_factor = cparams.yarn_ext_factor;
|
||||
const float attn_factor = cparams.yarn_attn_factor;
|
||||
const float beta_fast = cparams.yarn_beta_fast;
|
||||
const float beta_slow = cparams.yarn_beta_slow;
|
||||
|
||||
GGML_ASSERT(n_embd_head % n_rot == 0);
|
||||
|
||||
@ -3185,7 +3240,8 @@ static void llm_build_k_shift(
|
||||
ggml_element_size(kv.k)*n_embd_head,
|
||||
ggml_element_size(kv.k)*n_embd_gqa,
|
||||
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
|
||||
K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
|
||||
K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(tmp, "K_shifted", il);
|
||||
ggml_build_forward_expand(graph, tmp);
|
||||
}
|
||||
@ -3442,12 +3498,17 @@ struct llm_build_context {
|
||||
|
||||
const float freq_base;
|
||||
const float freq_scale;
|
||||
const float ext_factor;
|
||||
const float attn_factor;
|
||||
const float beta_fast;
|
||||
const float beta_slow;
|
||||
const float norm_eps;
|
||||
const float norm_rms_eps;
|
||||
|
||||
const int32_t n_tokens;
|
||||
const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx)
|
||||
const int32_t kv_head; // index of where we store new KV data in the cache
|
||||
const int32_t n_orig_ctx;
|
||||
|
||||
const bool do_rope_shift;
|
||||
|
||||
@ -3477,11 +3538,16 @@ struct llm_build_context {
|
||||
n_embd_gqa (hparams.n_embd_gqa()),
|
||||
freq_base (cparams.rope_freq_base),
|
||||
freq_scale (cparams.rope_freq_scale),
|
||||
ext_factor (cparams.yarn_ext_factor),
|
||||
attn_factor (cparams.yarn_attn_factor),
|
||||
beta_fast (cparams.yarn_beta_fast),
|
||||
beta_slow (cparams.yarn_beta_slow),
|
||||
norm_eps (hparams.f_norm_eps),
|
||||
norm_rms_eps (hparams.f_norm_rms_eps),
|
||||
n_tokens (batch.n_tokens),
|
||||
n_kv (worst_case ? n_ctx : kv_self.n),
|
||||
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
|
||||
n_orig_ctx (cparams.n_yarn_orig_ctx),
|
||||
do_rope_shift (worst_case || kv_self.has_shift),
|
||||
cb (cb),
|
||||
buf_compute (lctx.buf_compute) {
|
||||
@ -3532,7 +3598,7 @@ struct llm_build_context {
|
||||
|
||||
// shift the entire K-cache if needed
|
||||
if (do_rope_shift) {
|
||||
llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
|
||||
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
@ -3556,10 +3622,18 @@ struct llm_build_context {
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
Qcur = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
||||
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
Kcur = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
||||
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||
@ -3634,7 +3708,7 @@ struct llm_build_context {
|
||||
|
||||
// shift the entire K-cache if needed
|
||||
if (do_rope_shift) {
|
||||
llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
|
||||
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
@ -3658,8 +3732,16 @@ struct llm_build_context {
|
||||
|
||||
switch (model.type) {
|
||||
case MODEL_7B:
|
||||
Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
Qcur = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
||||
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
Kcur = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
||||
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
break;
|
||||
case MODEL_13B:
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
|
||||
@ -3746,7 +3828,7 @@ struct llm_build_context {
|
||||
|
||||
// shift the entire K-cache if needed
|
||||
if (do_rope_shift) {
|
||||
llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
|
||||
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
@ -3786,10 +3868,16 @@ struct llm_build_context {
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
// using mode = 2 for neox mode
|
||||
Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
Qcur = ggml_rope_custom(
|
||||
ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
Kcur = ggml_rope_custom(
|
||||
ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||
@ -3960,7 +4048,7 @@ struct llm_build_context {
|
||||
cb(KQ_mask, "KQ_mask", -1);
|
||||
|
||||
if (do_rope_shift) {
|
||||
llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
|
||||
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
@ -4053,13 +4141,15 @@ struct llm_build_context {
|
||||
cb(kpass, "kpass", il);
|
||||
|
||||
struct ggml_tensor * qrotated = ggml_rope_custom(
|
||||
ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
|
||||
);
|
||||
ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(qrotated, "qrotated", il);
|
||||
|
||||
struct ggml_tensor * krotated = ggml_rope_custom(
|
||||
ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
|
||||
);
|
||||
ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(krotated, "krotated", il);
|
||||
|
||||
// ggml currently only supports concatenation on dim=2
|
||||
@ -7883,8 +7973,13 @@ struct llama_context_params llama_context_default_params() {
|
||||
/*.n_batch =*/ 512,
|
||||
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
|
||||
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
|
||||
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
|
||||
/*.rope_freq_base =*/ 0.0f,
|
||||
/*.rope_freq_scale =*/ 0.0f,
|
||||
/*.yarn_ext_factor =*/ NAN,
|
||||
/*.yarn_attn_factor =*/ 1.0f,
|
||||
/*.yarn_beta_fast =*/ 32.0f,
|
||||
/*.yarn_beta_slow =*/ 1.0f,
|
||||
/*.mul_mat_q =*/ true,
|
||||
/*.f16_kv =*/ true,
|
||||
/*.logits_all =*/ false,
|
||||
@ -7971,10 +8066,7 @@ struct llama_model * llama_load_model_from_file(
|
||||
};
|
||||
}
|
||||
|
||||
if (!llama_model_load(path_model, *model, params.n_gpu_layers,
|
||||
params.main_gpu, params.tensor_split,
|
||||
params.use_mmap, params.use_mlock, params.vocab_only,
|
||||
params.progress_callback, params.progress_callback_user_data)) {
|
||||
if (!llama_model_load(path_model, *model, params)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
|
||||
delete model;
|
||||
return nullptr;
|
||||
@ -8000,13 +8092,35 @@ struct llama_context * llama_new_context_with_model(
|
||||
const auto & hparams = model->hparams;
|
||||
auto & cparams = ctx->cparams;
|
||||
|
||||
cparams.n_batch = params.n_batch;
|
||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||
cparams.rope_freq_base = params.rope_freq_base == 0 ? hparams.rope_freq_base_train : params.rope_freq_base;
|
||||
cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale;
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.mul_mat_q = params.mul_mat_q;
|
||||
cparams.n_batch = params.n_batch;
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.mul_mat_q = params.mul_mat_q;
|
||||
|
||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
||||
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
|
||||
|
||||
cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
|
||||
hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
|
||||
hparams.n_ctx_train;
|
||||
|
||||
auto rope_scaling_type = params.rope_scaling_type;
|
||||
if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
|
||||
rope_scaling_type = hparams.rope_scaling_type_train;
|
||||
}
|
||||
|
||||
if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) {
|
||||
cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
|
||||
}
|
||||
|
||||
if (std::isnan(cparams.yarn_ext_factor)) { // NaN indicates 'not set'
|
||||
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
params.seed = time(NULL);
|
||||
|
18
llama.h
18
llama.h
@ -106,6 +106,14 @@ extern "C" {
|
||||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||
};
|
||||
|
||||
enum llama_rope_scaling_type {
|
||||
LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
|
||||
LLAMA_ROPE_SCALING_NONE = 0,
|
||||
LLAMA_ROPE_SCALING_LINEAR = 1,
|
||||
LLAMA_ROPE_SCALING_YARN = 2,
|
||||
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
|
||||
};
|
||||
|
||||
typedef struct llama_token_data {
|
||||
llama_token id; // token id
|
||||
float logit; // log-odds of the token
|
||||
@ -172,10 +180,16 @@ extern "C" {
|
||||
uint32_t n_batch; // prompt processing maximum batch size
|
||||
uint32_t n_threads; // number of threads to use for generation
|
||||
uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||
int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||
float rope_freq_base; // RoPE base frequency, 0 = from model
|
||||
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
|
||||
float rope_freq_base; // RoPE base frequency, 0 = from model
|
||||
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
|
||||
float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model
|
||||
float yarn_attn_factor; // YaRN magnitude scaling factor
|
||||
float yarn_beta_fast; // YaRN low correction dim
|
||||
float yarn_beta_slow; // YaRN high correction dim
|
||||
uint32_t yarn_orig_ctx; // YaRN original context size
|
||||
|
||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
|
||||
|
Loading…
Reference in New Issue
Block a user