mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00
ggml : update get_rows f16 and q
This commit is contained in:
parent
ac3f7d8e23
commit
2e4db48291
5
Makefile
5
Makefile
@ -396,6 +396,11 @@ ifdef LLAMA_CUBLAS
|
||||
MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
|
||||
OBJS += ggml-cuda.o
|
||||
NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
|
||||
|
||||
ifdef LLAMA_DEBUG
|
||||
NVCCFLAGS += -lineinfo
|
||||
endif
|
||||
|
||||
ifdef LLAMA_CUDA_NVCC
|
||||
NVCC = $(LLAMA_CUDA_NVCC)
|
||||
else
|
||||
|
39
ggml.c
39
ggml.c
@ -4086,7 +4086,7 @@ struct ggml_tensor * ggml_mul_mat_id(
|
||||
GGML_ASSERT(ids->ne[1] == b->ne[1]);
|
||||
GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
|
||||
GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
|
||||
GGML_ASSERT(id >= 0 && id < n_as);
|
||||
GGML_ASSERT(id >= 0 && id < ids->ne[0]);
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
@ -10345,7 +10345,7 @@ static void ggml_compute_forward_get_rows_q(
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int64_t nc = ne00;
|
||||
const int64_t nr = ggml_nelements(src1);
|
||||
const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
|
||||
@ -10356,14 +10356,16 @@ static void ggml_compute_forward_get_rows_q(
|
||||
assert(ggml_nrows(dst) == nr);
|
||||
|
||||
// TODO: multi-thread
|
||||
for (int64_t i = 0; i < nr; ++i) {
|
||||
const int64_t r = ((int32_t *) src1->data)[i];
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
for (int64_t i10 = 0; i10 < ne10; ++i10) {
|
||||
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
||||
|
||||
const int64_t i02 = i/ne10;
|
||||
|
||||
dequantize_row_q(
|
||||
(const void *) ((char *) src0->data + i02*nb02 + r*nb01),
|
||||
(float *) ((char *) dst->data + i*nb1), nc);
|
||||
dequantize_row_q(
|
||||
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
|
||||
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -10381,7 +10383,7 @@ static void ggml_compute_forward_get_rows_f16(
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int64_t nc = ne00;
|
||||
const int64_t nr = ggml_nelements(src1);
|
||||
const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
|
||||
|
||||
assert(ne0 == nc);
|
||||
assert(ne02 == ne11);
|
||||
@ -10389,14 +10391,15 @@ static void ggml_compute_forward_get_rows_f16(
|
||||
assert(ggml_nrows(dst) == nr);
|
||||
|
||||
// TODO: multi-thread
|
||||
for (int64_t i = 0; i < nr; ++i) {
|
||||
const int64_t r = ((int32_t *) src1->data)[i];
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
for (int64_t i10 = 0; i10 < ne10; ++i10) {
|
||||
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
||||
|
||||
const int64_t i02 = i/ne10;
|
||||
|
||||
for (int j = 0; j < nc; ++j) {
|
||||
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
|
||||
((float *) ((char *) dst->data + i*nb1))[j] = GGML_FP16_TO_FP32(v);
|
||||
ggml_fp16_to_fp32_row(
|
||||
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
|
||||
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -10415,6 +10418,7 @@ static void ggml_compute_forward_get_rows_f32(
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int64_t nc = ne00;
|
||||
const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
|
||||
|
||||
assert(ne0 == nc);
|
||||
assert(ne02 == ne11);
|
||||
@ -10422,7 +10426,6 @@ static void ggml_compute_forward_get_rows_f32(
|
||||
assert(ggml_nrows(dst) == nr);
|
||||
|
||||
// TODO: multi-thread
|
||||
// TODO: same impl for get_rows_q and get_rows_f16
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
for (int64_t i10 = 0; i10 < ne10; ++i10) {
|
||||
|
Loading…
Reference in New Issue
Block a user