mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 21:10:24 +01:00
llamafile : improve sgemm.cpp (#6796)
* llamafile : improve sgemm.cpp - Re-enable by default - Fix issue described in #6716 - Make code more abstract, elegant, and maintainable - Faster handling of weirdly shaped `m` an `n` edge cases * Address review comments * Help clang produce fma instructions * Address review comments
This commit is contained in:
parent
e931888d50
commit
192090bae4
@ -43,17 +43,11 @@ else()
|
|||||||
set(LLAMA_METAL_DEFAULT OFF)
|
set(LLAMA_METAL_DEFAULT OFF)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# TODO: fix this for Android CI
|
if (CMAKE_SYSTEM_NAME MATCHES "ANDROID")
|
||||||
# https://github.com/ggerganov/llama.cpp/pull/6716#issuecomment-2061509191
|
|
||||||
#if (CMAKE_SYSTEM_NAME MATCHES "ANDROID")
|
|
||||||
# set(LLAMA_LLAMAFILE_DEFAULT OFF)
|
|
||||||
#else()
|
|
||||||
# set(LLAMA_LLAMAFILE_DEFAULT ON)
|
|
||||||
#endif()
|
|
||||||
|
|
||||||
# TODO: temporary disable until MoE is fixed
|
|
||||||
# https://github.com/ggerganov/llama.cpp/pull/6716
|
|
||||||
set(LLAMA_LLAMAFILE_DEFAULT OFF)
|
set(LLAMA_LLAMAFILE_DEFAULT OFF)
|
||||||
|
else()
|
||||||
|
set(LLAMA_LLAMAFILE_DEFAULT ON)
|
||||||
|
endif()
|
||||||
|
|
||||||
# general
|
# general
|
||||||
option(BUILD_SHARED_LIBS "build shared libraries" OFF)
|
option(BUILD_SHARED_LIBS "build shared libraries" OFF)
|
||||||
|
4
Makefile
4
Makefile
@ -384,10 +384,6 @@ ifdef LLAMA_OPENBLAS
|
|||||||
MK_LDFLAGS += $(shell pkg-config --libs openblas)
|
MK_LDFLAGS += $(shell pkg-config --libs openblas)
|
||||||
endif # LLAMA_OPENBLAS
|
endif # LLAMA_OPENBLAS
|
||||||
|
|
||||||
# TODO: temporary disable until MoE is fixed
|
|
||||||
# https://github.com/ggerganov/llama.cpp/pull/6716
|
|
||||||
LLAMA_NO_LLAMAFILE := 1
|
|
||||||
|
|
||||||
ifndef LLAMA_NO_LLAMAFILE
|
ifndef LLAMA_NO_LLAMAFILE
|
||||||
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
|
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
|
||||||
OBJS += sgemm.o
|
OBJS += sgemm.o
|
||||||
|
8
ggml.c
8
ggml.c
@ -10825,7 +10825,7 @@ static void ggml_compute_forward_mul_mat(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if GGML_USE_LLAMAFILE
|
#if GGML_USE_LLAMAFILE
|
||||||
if (nb10 == ggml_type_size(src1->type)) {
|
if (src1_cont) {
|
||||||
for (int64_t i13 = 0; i13 < ne13; i13++)
|
for (int64_t i13 = 0; i13 < ne13; i13++)
|
||||||
for (int64_t i12 = 0; i12 < ne12; i12++)
|
for (int64_t i12 = 0; i12 < ne12; i12++)
|
||||||
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
||||||
@ -10878,15 +10878,13 @@ UseGgmlGemm1:;
|
|||||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||||
|
|
||||||
#if GGML_USE_LLAMAFILE
|
#if GGML_USE_LLAMAFILE
|
||||||
if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) {
|
if (src1->type != vec_dot_type) {
|
||||||
for (int64_t i13 = 0; i13 < ne13; i13++)
|
for (int64_t i13 = 0; i13 < ne13; i13++)
|
||||||
for (int64_t i12 = 0; i12 < ne12; i12++)
|
for (int64_t i12 = 0; i12 < ne12; i12++)
|
||||||
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
||||||
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
||||||
nb01/ggml_type_size(src0->type),
|
nb01/ggml_type_size(src0->type),
|
||||||
(const char *)wdata + ggml_row_size(vec_dot_type,
|
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
|
||||||
nb12/ggml_type_size(src1->type)*i12 +
|
|
||||||
nb13/ggml_type_size(src1->type)*i13),
|
|
||||||
row_size/ggml_type_size(vec_dot_type),
|
row_size/ggml_type_size(vec_dot_type),
|
||||||
(char *)dst->data + i12*nb2 + i13*nb3,
|
(char *)dst->data + i12*nb2 + i13*nb3,
|
||||||
nb1/ggml_type_size(dst->type),
|
nb1/ggml_type_size(dst->type),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user