mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
ggml : add ggml_row_size() (fixes llama out of space) (#4461)
* Fixes "Not enough space in the context's memory pool" encountered on certain models, which seems to be caused by some imprecision related to the automatic casting of floating point values * do not cast to size_t, instead just use doubles * ggml : add ggml_row_size(), deprecate ggml_type_sizef() * ggml : fix row size compute to avoid overflows * tests : fix sizey -> sizez --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
55e87c3749
commit
20a68a7030
@ -129,13 +129,13 @@ int main(int argc, char ** argv) {
|
||||
const ggml_type qtype = GGML_TYPE_Q4_1;
|
||||
|
||||
size_t ctx_size = 0;
|
||||
ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32);
|
||||
ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32);
|
||||
ctx_size += sizex*sizez*ggml_type_sizef(GGML_TYPE_F32);
|
||||
ctx_size += sizex*sizey*ggml_type_sizef(qtype);
|
||||
ctx_size += sizex*sizey*ggml_type_sizef(qtype);
|
||||
ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32); // BLAS
|
||||
ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32); // BLAS
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey);
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey);
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizez);
|
||||
ctx_size += ggml_row_size(qtype, sizex*sizey);
|
||||
ctx_size += ggml_row_size(qtype, sizex*sizey);
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); // BLAS
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); // BLAS
|
||||
ctx_size += 1024*1024*16;
|
||||
|
||||
printf("Allocating Memory of size %zi bytes, %zi MB\n",ctx_size, (ctx_size/1024/1024));
|
||||
|
9
ggml.c
9
ggml.c
@ -2011,8 +2011,13 @@ size_t ggml_type_size(enum ggml_type type) {
|
||||
return type_traits[type].type_size;
|
||||
}
|
||||
|
||||
float ggml_type_sizef(enum ggml_type type) {
|
||||
return ((float)(type_traits[type].type_size))/type_traits[type].blck_size;
|
||||
size_t ggml_row_size(enum ggml_type type, int64_t ne) {
|
||||
assert(ne % ggml_blck_size(type) == 0);
|
||||
return ggml_type_size(type)*ne/ggml_blck_size(type);
|
||||
}
|
||||
|
||||
double ggml_type_sizef(enum ggml_type type) {
|
||||
return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
|
||||
}
|
||||
|
||||
const char * ggml_type_name(enum ggml_type type) {
|
||||
|
6
ggml.h
6
ggml.h
@ -643,7 +643,11 @@ extern "C" {
|
||||
|
||||
GGML_API int ggml_blck_size(enum ggml_type type);
|
||||
GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
|
||||
GGML_API float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
|
||||
GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
|
||||
|
||||
GGML_DEPRECATED(
|
||||
GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
|
||||
"use ggml_row_size() instead");
|
||||
|
||||
GGML_API const char * ggml_type_name(enum ggml_type type);
|
||||
GGML_API const char * ggml_op_name (enum ggml_op op);
|
||||
|
12
llama.cpp
12
llama.cpp
@ -1555,7 +1555,7 @@ static bool llama_kv_cache_init(
|
||||
cache.cells.clear();
|
||||
cache.cells.resize(n_ctx);
|
||||
|
||||
cache.buf.resize(n_elements*(ggml_type_sizef(ktype) + ggml_type_sizef(vtype)) + 2u*n_layer*ggml_tensor_overhead());
|
||||
cache.buf.resize(ggml_row_size(ktype, n_elements) + ggml_row_size(vtype, n_elements) + 2u*n_layer*ggml_tensor_overhead());
|
||||
memset(cache.buf.data, 0, cache.buf.size);
|
||||
|
||||
struct ggml_init_params params;
|
||||
@ -3822,8 +3822,8 @@ static void llm_build_k_shift(
|
||||
ggml_rope_custom_inplace(ctx,
|
||||
ggml_view_3d(ctx, kv.k_l[il],
|
||||
n_embd_head, n_head_kv, n_ctx,
|
||||
ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
|
||||
ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
|
||||
ggml_row_size(kv.k_l[il]->type, n_embd_head),
|
||||
ggml_row_size(kv.k_l[il]->type, n_embd_gqa),
|
||||
0),
|
||||
K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
@ -3852,7 +3852,7 @@ static void llm_build_kv_store(
|
||||
cb(v_cur_t, "v_cur_t", il);
|
||||
|
||||
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa,
|
||||
(ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa)*kv_head);
|
||||
(ggml_row_size(kv.k_l[il]->type, n_embd_gqa))*kv_head);
|
||||
cb(k_cache_view, "k_cache_view", il);
|
||||
|
||||
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa,
|
||||
@ -4011,8 +4011,8 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
struct ggml_tensor * k =
|
||||
ggml_view_3d(ctx, kv.k_l[il],
|
||||
n_embd_head, n_kv, n_head_kv,
|
||||
ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
|
||||
ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
|
||||
ggml_row_size(kv.k_l[il]->type, n_embd_gqa),
|
||||
ggml_row_size(kv.k_l[il]->type, n_embd_head),
|
||||
0);
|
||||
cb(k, "k", il);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user