ggml : adjust mul_mat_f16 work memory (#1226)
* llama : minor - remove explicit int64_t cast
* ggml : reduce memory buffer for F16 mul_mat when not using cuBLAS
* ggml : add asserts to guard against incorrect wsize
parent 305eb5afd5
commit 214b6a3570
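To illustrate the last bullet, here is a minimal, self-contained sketch of the guard pattern the new asserts follow; this is not code from the commit, and fp16_t, fp32_to_fp16, mm_params, and convert_rows_f32_to_f16 are stand-ins for the corresponding ggml types. A conversion loop packs values into the shared work buffer through a running index, and the assert checks that the bytes written never exceed the wsize planned in ggml_graph_compute.

/* Sketch only: stand-in types, not the actual ggml definitions. */
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

typedef uint16_t fp16_t;                                     /* stand-in for ggml_fp16_t       */
static fp16_t fp32_to_fp16(float x) { (void) x; return 0; }  /* stand-in for GGML_FP32_TO_FP16 */

struct mm_params {
    void * wdata;   /* shared scratch buffer               */
    size_t wsize;   /* its size, planned by the graph code */
};

/* Convert one nrows x ncols F32 matrix into the F16 work buffer. */
static void convert_rows_f32_to_f16(const float * src, int64_t nrows, int64_t ncols,
                                    const struct mm_params * params) {
    fp16_t * const wdata = params->wdata;

    size_t id = 0;
    for (int64_t i1 = 0; i1 < nrows; ++i1) {
        for (int64_t i0 = 0; i0 < ncols; ++i0) {
            wdata[id++] = fp32_to_fp16(src[i1*ncols + i0]);
        }
    }

    /* the guard this commit adds: what was written must fit in the work buffer */
    assert(id*sizeof(fp16_t) <= params->wsize);
}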
Makefile (9 changed lines)
@@ -34,10 +34,15 @@ endif
 #
 
 # keep standard at C11 and C++11
-CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
-CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
+CFLAGS = -I. -O3 -std=c11 -fPIC
+CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
 LDFLAGS =
 
+ifndef LLAMA_DEBUG
+    CFLAGS += -DNDEBUG
+    CXXFLAGS += -DNDEBUG
+endif
+
 # warnings
 CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith
 CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar
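With this Makefile change, -DNDEBUG is only added when LLAMA_DEBUG is undefined, so a build invoked with LLAMA_DEBUG set (for example `make clean && make LLAMA_DEBUG=1`) should keep assert() enabled and let the new wsize guards in ggml.c actually fire; a default build still defines NDEBUG and compiles them out.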
ggml.c (21 changed lines)
@@ -8245,8 +8245,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
         ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
         float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-#else
-        float * const wdata = params->wdata;
 #endif
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -8263,8 +8261,11 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
                     }
                 }
+
+                assert(id*sizeof(ggml_fp16_t) <= params->wsize);
             }
 #else
+            float * const wdata = params->wdata;
             {
                 size_t id = 0;
                 for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -8272,6 +8273,8 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
                     }
                 }
+
+                assert(id*sizeof(float) <= params->wsize);
             }
 #endif
 
@@ -8537,7 +8540,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
                     dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
                     id += ne00;
                 }
+
+                assert(id*sizeof(float) <= params->wsize);
             }
+
             const float * x = wdata;
 #endif
 
@@ -11571,10 +11577,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                         node->n_tasks = 1; // TODO: this actually is doing nothing
                                            //       the threads are still spinning
-                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0));
-                        //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
-                        //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
-                        //printf("cur = %zu\n", cur);
+#if defined(GGML_USE_CUBLAS)
+                        // with cuBLAS, we need memory for the full 3D / 4D data of src1
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
+#else
+                        // here we need memory just for single 2D matrix from src0
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+#endif
                     } else {
                         cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
                     }
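To make the new work-size planning in the last hunk concrete, here is a small standalone sketch; only the two formulas come from the diff, while the tensor shapes (a 4096x4096 F16 src0 and a 4096x512 F32 src1) and the printout are made-up example values.

/* Standalone illustration of the two work-size formulas above; shapes are
 * arbitrary example values, not taken from a real model. */
#include <stdint.h>
#include <stdio.h>

int main(void) {
    /* hypothetical F16 mul_mat: src0 = 4096 x 4096 F16 weights,
     * src1 = 4096 x 512 F32 activations, both 2D in this example */
    const int64_t src0_ne[4] = { 4096, 4096, 1, 1 };
    const int64_t src1_ne[4] = { 4096,  512, 1, 1 };

    const int64_t src1_nelements = src1_ne[0]*src1_ne[1]*src1_ne[2]*src1_ne[3];

    /* GGML_USE_CUBLAS: all of src1 is converted to F16, so the planned size
     * is sizeof(F16) * nelements(src1)                                      */
    const size_t cur_cublas = 2u /* bytes per F16 */ * (size_t) src1_nelements;

    /* CPU BLAS path after this commit: only a single 2D matrix of src0 is
     * converted to F32 at a time, so sizeof(float) * ne[0] * ne[1] suffices */
    const size_t cur_cpu = sizeof(float) * (size_t)(src0_ne[0]*src0_ne[1]);

    printf("cuBLAS path: cur = %zu bytes (4 MiB here)\n",  cur_cublas);
    printf("CPU path:    cur = %zu bytes (64 MiB here)\n", cur_cpu);
    return 0;
}

Compared with the removed line, which reserved GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0)), the non-cuBLAS bound now covers only one 2D matrix of src0, so it is never larger and can be considerably smaller when src1 carries a large batch or extra dimensions.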
llama.cpp
@@ -780,7 +780,7 @@ static bool kv_cache_init(
     const int n_embd  = hparams.n_embd;
     const int n_layer = hparams.n_layer;
 
-    const int64_t n_mem      = (int64_t)n_layer*n_ctx;
+    const int64_t n_mem      = n_layer*n_ctx;
     const int64_t n_elements = n_embd*n_mem;
 
     cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
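For a sense of scale, here is a worked version of the sizing in kv_cache_init; the hyper-parameters are assumptions for illustration (roughly a LLaMA-7B model with a 512-token context and an F16 cache), not values read from this commit.

/* Worked example of the kv_cache_init buffer size; all numbers below are
 * assumed example values. */
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int    n_embd     = 4096;     /* assumed embedding size   */
    const int    n_layer    = 32;       /* assumed layer count      */
    const int    n_ctx      = 512;      /* assumed context length   */
    const size_t wtype_size = 2;        /* assuming an F16 KV cache */
    const size_t MB = 1024*1024;

    /* same formulas as the hunk above; n_mem is already int64_t, so the
     * n_embd*n_mem product is evaluated in 64 bits without the removed cast */
    const int64_t n_mem      = n_layer*n_ctx;
    const int64_t n_elements = n_embd*n_mem;

    /* one buffer holds both K and V, plus a small fixed margin */
    const size_t buf_size = 2u*n_elements*wtype_size + 2u*MB;

    printf("n_mem = %lld, n_elements = %lld\n", (long long) n_mem, (long long) n_elements);
    printf("KV cache buffer: %zu bytes (~%zu MiB)\n", buf_size, buf_size/MB);
    return 0;
}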