mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
ggml : add asserts for type conversion in fattn kernels (#9971)
ggml-ci
This commit is contained in:
parent
d5ebd79c76
commit
f594bc80ba
@ -1035,7 +1035,7 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
|
|||||||
return GGML_TYPE_Q5_1;
|
return GGML_TYPE_Q5_1;
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::runtime_error("Invalid cache type: " + s);
|
throw std::runtime_error("Unsupported cache type: " + s);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_context_params common_context_params_to_llama(const common_params & params) {
|
struct llama_context_params common_context_params_to_llama(const common_params & params) {
|
||||||
|
@ -324,8 +324,9 @@ struct ggml_logger_state {
|
|||||||
static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
|
static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
|
||||||
|
|
||||||
static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
|
static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
|
||||||
if (format == NULL)
|
if (format == NULL) {
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
va_list args_copy;
|
va_list args_copy;
|
||||||
va_copy(args_copy, args);
|
va_copy(args_copy, args);
|
||||||
char buffer[128];
|
char buffer[128];
|
||||||
@ -15723,6 +15724,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
|
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
|
||||||
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
|
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
|
||||||
|
|
||||||
|
GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
|
||||||
|
GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
|
||||||
|
|
||||||
// loop over n_batch and n_head
|
// loop over n_batch and n_head
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
for (int ir = ir0; ir < ir1; ++ir) {
|
||||||
// q indices
|
// q indices
|
||||||
|
@ -19243,7 +19243,7 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
params.flash_attn = false;
|
params.flash_attn = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) {
|
if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
|
||||||
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
|
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user