diff --git a/Makefile b/Makefile index 1aab79df7..3876e93d0 100644 --- a/Makefile +++ b/Makefile @@ -71,21 +71,21 @@ OPT = -Ofast else OPT = -O3 endif -CFLAGS = -I. $(OPT) -std=c11 -fPIC -CXXFLAGS = -I. -I./common $(OPT) -std=c++11 -fPIC -LDFLAGS = +MK_CPPFLAGS = -I. -Icommon +MK_CFLAGS = $(CPPFLAGS) $(OPT) -std=c11 -fPIC +MK_CXXFLAGS = $(CPPFLAGS) $(OPT) -std=c++11 -fPIC +MK_LDFLAGS = ifdef LLAMA_DEBUG - CFLAGS += -O0 -g - CXXFLAGS += -O0 -g - LDFLAGS += -g + MK_CFLAGS += -O0 -g + MK_CXXFLAGS += -O0 -g + MK_LDFLAGS += -g else - CFLAGS += -DNDEBUG - CXXFLAGS += -DNDEBUG + MK_CPPFLAGS += -DNDEBUG endif ifdef LLAMA_SERVER_VERBOSE - CXXFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE) + MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE) endif ifdef LLAMA_DISABLE_LOGS @@ -94,9 +94,9 @@ ifdef LLAMA_DISABLE_LOGS endif # LLAMA_DISABLE_LOGS # warnings -CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith \ - -Wmissing-prototypes -Werror=implicit-int -Wno-unused-function -CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar +MK_CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith \ + -Wmissing-prototypes -Werror=implicit-int -Wno-unused-function +MK_CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar ifeq '' '$(findstring clang++,$(CXX))' # g++ only @@ -105,29 +105,9 @@ endif # OS specific # TODO: support Windows -ifeq ($(UNAME_S),Linux) - CFLAGS += -pthread - CXXFLAGS += -pthread -endif -ifeq ($(UNAME_S),Darwin) - CFLAGS += -pthread - CXXFLAGS += -pthread -endif -ifeq ($(UNAME_S),FreeBSD) - CFLAGS += -pthread - CXXFLAGS += -pthread -endif -ifeq ($(UNAME_S),NetBSD) - CFLAGS += -pthread - CXXFLAGS += -pthread -endif -ifeq ($(UNAME_S),OpenBSD) - CFLAGS += -pthread - CXXFLAGS += -pthread -endif -ifeq ($(UNAME_S),Haiku) - CFLAGS += -pthread - CXXFLAGS += -pthread +ifneq '' '$(filter $(UNAME_S),Linux Darwin FreeBSD NetBSD OpenBSD Haiku)' + MK_CFLAGS += -pthread + MK_CXXFLAGS += -pthread endif # detect Windows @@ -153,12 +133,11 @@ ifeq ($(_WIN32),1) endif ifdef LLAMA_GPROF - CFLAGS += -pg - CXXFLAGS += -pg + MK_CFLAGS += -pg + MK_CXXFLAGS += -pg endif ifdef LLAMA_PERF - CFLAGS += -DGGML_PERF - CXXFLAGS += -DGGML_PERF + MK_CPPFLAGS += -DGGML_PERF endif # Architecture specific @@ -169,16 +148,16 @@ ifndef RISCV ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64)) # Use all CPU extensions that are available: - CFLAGS += -march=native -mtune=native - CXXFLAGS += -march=native -mtune=native + MK_CFLAGS += -march=native -mtune=native + MK_CXXFLAGS += -march=native -mtune=native # Usage AVX-only - #CFLAGS += -mfma -mf16c -mavx - #CXXFLAGS += -mfma -mf16c -mavx + #MK_CFLAGS += -mfma -mf16c -mavx + #MK_CXXFLAGS += -mfma -mf16c -mavx # Usage SSSE3-only (Not is SSE3!) - #CFLAGS += -mssse3 - #CXXFLAGS += -mssse3 + #MK_CFLAGS += -mssse3 + #MK_CXXFLAGS += -mssse3 endif # The stack is only 16-byte aligned on Windows, so don't let gcc emit aligned moves. @@ -192,34 +171,33 @@ endif ifneq ($(filter aarch64%,$(UNAME_M)),) # Apple M1, M2, etc. # Raspberry Pi 3, 4, Zero 2 (64-bit) - CFLAGS += -mcpu=native - CXXFLAGS += -mcpu=native + MK_CFLAGS += -mcpu=native + MK_CXXFLAGS += -mcpu=native endif ifneq ($(filter armv6%,$(UNAME_M)),) # Raspberry Pi 1, Zero - CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access + MK_CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access + MK_CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access endif ifneq ($(filter armv7%,$(UNAME_M)),) # Raspberry Pi 2 - CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations + MK_CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations + MK_CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations endif ifneq ($(filter armv8%,$(UNAME_M)),) # Raspberry Pi 3, 4, Zero 2 (32-bit) - CFLAGS += -mfp16-format=ieee -mno-unaligned-access + MK_CFLAGS += -mfp16-format=ieee -mno-unaligned-access + MK_CXXFLAGS += -mfp16-format=ieee -mno-unaligned-access endif ifneq ($(filter ppc64%,$(UNAME_M)),) POWER9_M := $(shell grep "POWER9" /proc/cpuinfo) ifneq (,$(findstring POWER9,$(POWER9_M))) - CFLAGS += -mcpu=power9 - CXXFLAGS += -mcpu=power9 - endif - # Require c++23's std::byteswap for big-endian support. - ifeq ($(UNAME_M),ppc64) - CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN + MK_CFLAGS += -mcpu=power9 + MK_CXXFLAGS += -mcpu=power9 endif endif @@ -229,12 +207,10 @@ else endif ifndef LLAMA_NO_K_QUANTS - CFLAGS += -DGGML_USE_K_QUANTS - CXXFLAGS += -DGGML_USE_K_QUANTS + MK_CPPFLAGS += -DGGML_USE_K_QUANTS OBJS += k_quants.o ifdef LLAMA_QKK_64 - CFLAGS += -DGGML_QKK_64 - CXXFLAGS += -DGGML_QKK_64 + MK_CPPFLAGS += -DGGML_QKK_64 endif endif @@ -242,8 +218,8 @@ ifndef LLAMA_NO_ACCELERATE # Mac OS - include Accelerate framework. # `-framework Accelerate` works both with Apple Silicon and Mac Intel ifeq ($(UNAME_S),Darwin) - CFLAGS += -DGGML_USE_ACCELERATE - LDFLAGS += -framework Accelerate + MK_CPPFLAGS += -DGGML_USE_ACCELERATE + MK_LDFLAGS += -framework Accelerate endif endif # LLAMA_NO_ACCELERATE @@ -258,25 +234,26 @@ ifndef LLAMA_NO_METAL endif # LLAMA_NO_METAL ifdef LLAMA_MPI - CFLAGS += -DGGML_USE_MPI -Wno-cast-qual - CXXFLAGS += -DGGML_USE_MPI -Wno-cast-qual + MK_CPPFLAGS += -DGGML_USE_MPI + MK_CFLAGS += -Wno-cast-qual + MK_CXXFLAGS += -Wno-cast-qual OBJS += ggml-mpi.o endif # LLAMA_MPI ifdef LLAMA_OPENBLAS - CFLAGS += -DGGML_USE_OPENBLAS $(shell pkg-config --cflags openblas) - LDFLAGS += $(shell pkg-config --libs openblas) + MK_CPPFLAGS += -DGGML_USE_OPENBLAS $(shell pkg-config --cflags-only-I openblas) + MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas) + MK_LDFLAGS += $(shell pkg-config --libs openblas) endif # LLAMA_OPENBLAS ifdef LLAMA_BLIS - CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/blis -I/usr/include/blis - LDFLAGS += -lblis -L/usr/local/lib + MK_CPPFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/blis -I/usr/include/blis + MK_LDFLAGS += -lblis -L/usr/local/lib endif # LLAMA_BLIS ifdef LLAMA_CUBLAS - CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include - CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include - LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib + MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include + MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib OBJS += ggml-cuda.o NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math ifdef LLAMA_CUDA_NVCC @@ -327,14 +304,15 @@ endif # LLAMA_CUBLAS ifdef LLAMA_CLBLAST - CFLAGS += -DGGML_USE_CLBLAST $(shell pkg-config --cflags clblast OpenCL) - CXXFLAGS += -DGGML_USE_CLBLAST $(shell pkg-config --cflags clblast OpenCL) + MK_CPPFLAGS += -DGGML_USE_CLBLAST $(shell pkg-config --cflags-only-I clblast OpenCL) + MK_CFLAGS += $(shell pkg-config --cflags-only-other clblast OpenCL) + MK_CXXFLAGS += $(shell pkg-config --cflags-only-other clblast OpenCL) # Mac provides OpenCL as a framework ifeq ($(UNAME_S),Darwin) - LDFLAGS += -lclblast -framework OpenCL + MK_LDFLAGS += -lclblast -framework OpenCL else - LDFLAGS += $(shell pkg-config --libs clblast OpenCL) + MK_LDFLAGS += $(shell pkg-config --libs clblast OpenCL) endif OBJS += ggml-opencl.o @@ -349,10 +327,9 @@ ifdef LLAMA_HIPBLAS LLAMA_CUDA_DMMV_X ?= 32 LLAMA_CUDA_MMV_Y ?= 1 LLAMA_CUDA_KQUANTS_ITER ?= 2 - CFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS - CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS - LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib - LDFLAGS += -lhipblas -lamdhip64 -lrocblas + MK_CPPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS + MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib + MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) HIPFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) HIPFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y) @@ -366,6 +343,12 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $< endif # LLAMA_HIPBLAS +ifndef LLAMA_NO_METAL + MK_CPPFLAGS += -DGGML_USE_METAL #-DGGML_METAL_NDEBUG + MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit + OBJS += ggml-metal.o +endif # LLAMA_METAL + ifndef LLAMA_NO_METAL ggml-metal.o: ggml-metal.m ggml-metal.h $(CC) $(CFLAGS) -c $< -o $@ @@ -376,11 +359,17 @@ ggml-mpi.o: ggml-mpi.c ggml-mpi.h $(CC) $(CFLAGS) -c $< -o $@ endif # LLAMA_MPI -ifdef LLAMA_NO_K_QUANTS +ifndef LLAMA_NO_K_QUANTS k_quants.o: k_quants.c k_quants.h $(CC) $(CFLAGS) -c $< -o $@ endif # LLAMA_NO_K_QUANTS +# combine build flags with cmdline overrides +override CPPFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) +override CFLAGS := $(MK_CFLAGS) $(CFLAGS) +override CXXFLAGS := $(MK_CXXFLAGS) $(CXXFLAGS) +override LDFLAGS := $(MK_LDFLAGS) $(LDFLAGS) + # # Print build information # diff --git a/Package.swift b/Package.swift index 73d027c70..96f52c4f0 100644 --- a/Package.swift +++ b/Package.swift @@ -12,9 +12,18 @@ let package = Package( name: "llama", path: ".", exclude: ["ggml-metal.metal"], - sources: ["ggml.c", "llama.cpp"], + sources: [ + "ggml.c", + "llama.cpp", + "ggml-alloc.c", + "k_quants.c" + ], publicHeadersPath: "spm-headers", - cSettings: [.unsafeFlags(["-Wno-shorten-64-to-32"]), .define("GGML_USE_ACCELERATE")], + cSettings: [ + .unsafeFlags(["-Wno-shorten-64-to-32"]), + .define("GGML_USE_K_QUANTS"), + .define("GGML_USE_ACCELERATE") + ], linkerSettings: [ .linkedFramework("Accelerate") ] diff --git a/README.md b/README.md index 4e6a0957d..0cfd94db4 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,7 @@ as the main playground for developing new features for the [ggml](https://github - [nat/openplayground](https://github.com/nat/openplayground) - [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) +- [withcatai/catai](https://github.com/withcatai/catai) --- @@ -464,6 +465,8 @@ Building the program with BLAS support may lead to some performance improvements You will need the [OpenCL SDK](https://github.com/KhronosGroup/OpenCL-SDK). - For Ubuntu or Debian, the packages `opencl-headers`, `ocl-icd` may be needed. + - For Windows, a pre-built SDK is available on the [OpenCL Releases](https://github.com/KhronosGroup/OpenCL-SDK/releases) page. + -
Installing the OpenCL SDK from source @@ -481,10 +484,27 @@ Building the program with BLAS support may lead to some performance improvements ```
- Installing CLBlast: it may be found in your operating system's packages. + ##### Installing CLBlast + + Pre-built CLBlast binaries may be found on the [CLBlast Releases](https://github.com/CNugteren/CLBlast/releases) page. For Unix variants, it may also be found in your operating system's packages. + + Alternatively, they may be built from source. -
- If not, then installing from source: + Windows: + + ```cmd + set OPENCL_SDK_ROOT="C:/OpenCL-SDK-v2023.04.17-Win-x64" + git clone https://github.com/CNugteren/CLBlast.git + mkdir CLBlast\build + cd CLBlast\build + cmake .. -DBUILD_SHARED_LIBS=OFF -DOVERRIDE_MSVC_FLAGS_TO_MT=OFF -DTUNERS=OFF -DOPENCL_ROOT=%OPENCL_SDK_ROOT% -G "Visual Studio 17 2022" -A x64 + cmake --build . --config Release + cmake --install . --prefix C:/CLBlast + ``` + + -
+ Unix: ```sh git clone https://github.com/CNugteren/CLBlast.git @@ -498,21 +518,32 @@ Building the program with BLAS support may lead to some performance improvements Where `/some/path` is where the built library will be installed (default is `/usr/local`).
- Building: + ##### Building Llama with CLBlast - Build with make: ```sh make LLAMA_CLBLAST=1 ``` - - CMake: + - CMake (Unix): ```sh mkdir build cd build cmake .. -DLLAMA_CLBLAST=ON -DCLBlast_dir=/some/path cmake --build . --config Release ``` + - CMake (Windows): + ```cmd + set CL_BLAST_CMAKE_PKG="C:/CLBlast/lib/cmake/CLBlast" + git clone https://github.com/ggerganov/llama.cpp + cd llama.cpp + mkdir build + cd build + cmake .. -DBUILD_SHARED_LIBS=OFF -DLLAMA_CLBLAST=ON -DCMAKE_PREFIX_PATH=%CL_BLAST_CMAKE_PKG% -G "Visual Studio 17 2022" -A x64 + cmake --build . --config Release + cmake --install . --prefix C:/LlamaCPP + ``` - Running: + ##### Running Llama with CLBlast The CLBlast build supports `--gpu-layers|-ngl` like the CUDA version does. diff --git a/common/log.h b/common/log.h index bf9fafd68..0b9b01052 100644 --- a/common/log.h +++ b/common/log.h @@ -341,14 +341,14 @@ inline FILE *log_handler1_impl(bool change = false, LogTriState disable = LogTri } } + if (_disabled) + { + // Log is disabled + return nullptr; + } + if (_initialized) { - if (_disabled) - { - // Log is disabled - return nullptr; - } - // with fallback in case something went wrong return logfile ? logfile : stderr; } diff --git a/convert.py b/convert.py index 6c89b5ecc..5a7483b43 100755 --- a/convert.py +++ b/convert.py @@ -323,15 +323,27 @@ class BpeVocab: self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) added_tokens: dict[str, int] if fname_added_tokens is not None: + # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab. added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) else: - added_tokens = {} + # Fall back to trying to find the added tokens in tokenizer.json + tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json' + if not tokenizer_json_file.is_file(): + added_tokens = {} + else: + tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8")) + added_tokens = dict( + (item['content'], item['id']) + for item in tokenizer_json.get('added_tokens', []) + # Added tokens here can be duplicates of the main vocabulary. + if item['content'] not in self.bpe_tokenizer ) vocab_size: int = len(self.bpe_tokenizer) expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) actual_ids = sorted(added_tokens.values()) if expected_ids != actual_ids: - raise Exception(f"Expected added token IDs to be sequential and start at {len(added_tokens)}; got {actual_ids}") + expected_end_id = vocab_size + len(actual_ids) - 1 + raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}") items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) self.added_tokens_list = [text for (text, idx) in items] @@ -345,10 +357,22 @@ class BpeVocab: from transformers.models.gpt2 import tokenization_gpt2 # type: ignore[import] byte_encoder = tokenization_gpt2.bytes_to_unicode() byte_decoder = {v: k for k, v in byte_encoder.items()} + score = 0.0 for i, item in enumerate(tokenizer): text: bytes = item.encode("utf-8") - score: float = -i - yield text, score, gguf.TokenType.USER_DEFINED + # FIXME: These shouldn't be hardcoded, but it's probably better than the current behavior? + if i <= 258 and text.startswith(b'<') and text.endswith(b'>'): + if i == 0 and text == b'': + toktype = gguf.TokenType.UNKNOWN + elif i == 1 or i == 2: + toktype = gguf.TokenType.CONTROL + elif i >= 3 and text.startswith(b'<0x'): + toktype = gguf.TokenType.BYTE + else: + toktype = gguf.TokenType.NORMAL + else: + toktype = gguf.TokenType.NORMAL + yield text, score, toktype def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: for text in self.added_tokens_list: diff --git a/examples/gptneox-wip/gptneox-main.cpp b/examples/gptneox-wip/gptneox-main.cpp index 04af50245..6291523f2 100644 --- a/examples/gptneox-wip/gptneox-main.cpp +++ b/examples/gptneox-wip/gptneox-main.cpp @@ -660,9 +660,10 @@ bool gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt2 ggml_tensor * gpt_neox_ff( const gpt_neox_block &block, ggml_context * ctx0, - ggml_tensor * inp) { + ggml_tensor * inp, + const gpt_neox_hparams &hparams) { - ggml_tensor * cur = ggml_norm(ctx0, inp); + ggml_tensor * cur = ggml_norm(ctx0, inp, hparams.norm_eps); cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, block.ln_2_g, cur), cur), ggml_repeat(ctx0, block.ln_2_b, cur)); cur = ggml_mul_mat(ctx0, block.c_mlp_fc_w, cur); @@ -753,7 +754,7 @@ bool gpt_neox_eval( // self-attention { { - cur = ggml_norm(ctx0, inpL); + cur = ggml_norm(ctx0, inpL, hparams.norm_eps); cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.blocks[il].ln_1_g, cur), cur), @@ -844,7 +845,7 @@ bool gpt_neox_eval( if (hparams.par_res == 0) { struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL); - cur = gpt_neox_ff(model.blocks[il], ctx0, inpFF); + cur = gpt_neox_ff(model.blocks[il], ctx0, inpFF, hparams); // input for next layer inpL = ggml_add(ctx0, cur, inpFF); @@ -853,7 +854,7 @@ bool gpt_neox_eval( // this is independent of the self-attention result, so it could be done in parallel to the self-attention // note here we pass inpL instead of cur - cur = gpt_neox_ff(model.blocks[il], ctx0, inpL); + cur = gpt_neox_ff(model.blocks[il], ctx0, inpL, hparams); // layer input + FF cur = ggml_add(ctx0, cur, inpFF); @@ -867,7 +868,7 @@ bool gpt_neox_eval( // norm { - inpL = ggml_norm(ctx0, inpL); + inpL = ggml_norm(ctx0, inpL, hparams.norm_eps); // inpL = ln_f_g*inpL + ln_f_b inpL = ggml_add(ctx0, diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 09eac2ec2..94def943b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1379,7 +1379,13 @@ int main(int argc, char **argv) } } - const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs); + auto probs = llama.generated_token_probs; + if (llama.params.n_probs > 0 && llama.stopped_word) { + const std::vector stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false); + probs = std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size()); + } + + const json data = format_final_response(llama, llama.generated_text, probs); llama_print_timings(llama.ctx); @@ -1456,7 +1462,11 @@ int main(int argc, char **argv) if (!llama.has_next_token) { // Generation is done, send extra information. - const json data = format_final_response(llama, "", llama.generated_token_probs); + const json data = format_final_response( + llama, + "", + std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index) + ); const std::string str = "data: " + diff --git a/ggml-alloc.c b/ggml-alloc.c index f07a4a217..459f121ca 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -284,7 +284,14 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) // address and size of the buffer when measuring // it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers static void * const MEASURE_BASE_ADDR = (void *) 0x1000; +#if defined(__ARM_NEON) && !defined(__aarch64__) +// 32-bit +// TODO: Use for 32-bit x86 as well +static const size_t MEASURE_MAX_SIZE = (1ULL<<32) - 1; // 4 GB +#else +// 64-bit static const size_t MEASURE_MAX_SIZE = 1ULL<<40; // 1 TB +#endif struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) { struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */); diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5fd625630..8357f32f7 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -81,12 +81,29 @@ #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); static __device__ __forceinline__ int __vsubss4(const int a, const int b) { const int8x4_t va = reinterpret_cast(a); const int8x4_t vb = reinterpret_cast(b); +#if __has_builtin(__builtin_elementwise_sub_sat) const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); return reinterpret_cast(c); +#else + int8x4_t c; + int16_t tmp; +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp = va[i] - vb[i]; + if(tmp > std::numeric_limits::max()) tmp = std::numeric_limits::max(); + if(tmp < std::numeric_limits::min()) tmp = std::numeric_limits::min(); + c[i] = tmp; + } + return reinterpret_cast(c); +#endif // __has_builtin(__builtin_elementwise_sub_sat) } static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { diff --git a/ggml-metal.m b/ggml-metal.m index 4267db9be..88e7e1356 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -116,10 +116,24 @@ static NSString * const msl_library_source = @"see metal.metal"; struct ggml_metal_context * ggml_metal_init(int n_cb) { metal_printf("%s: allocating\n", __func__); - struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context)); + // Show all the Metal device instances in the system + NSArray * devices = MTLCopyAllDevices(); + id device; + NSString * s; + for (device in devices) { + s = [device name]; + metal_printf("%s: found device: %s\n", __func__, [s UTF8String]); + } + // Pick and show default Metal device + device = MTLCreateSystemDefaultDevice(); + s = [device name]; + metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]); + + // Configure context + struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context)); + ctx->device = device; ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); - ctx->device = MTLCreateSystemDefaultDevice(); ctx->queue = [ctx->device newCommandQueue]; ctx->n_buffers = 0; ctx->concur_list_len = 0; diff --git a/ggml.c b/ggml.c index cf3955f7f..38b1155c1 100644 --- a/ggml.c +++ b/ggml.c @@ -817,46 +817,6 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 #if !defined(__aarch64__) -inline static uint16_t vaddvq_u8(uint8x16_t v) { - return - (uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) + - (uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) + - (uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) + - (uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) + - (uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) + - (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) + - (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) + - (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15); -} - -inline static int16_t vaddvq_s8(int8x16_t v) { - return - (int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) + - (int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) + - (int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) + - (int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) + - (int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) + - (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) + - (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) + - (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15); -} - -inline static int32_t vaddvq_s16(int16x8_t v) { - return - (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + - (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + - (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + - (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); -} - -inline static uint32_t vaddvq_u16(uint16x8_t v) { - return - (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) + - (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) + - (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) + - (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7); -} - inline static int32_t vaddvq_s32(int32x4_t v) { return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); } @@ -865,12 +825,6 @@ inline static float vaddvq_f32(float32x4_t v) { return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); } -inline static float vminvq_f32(float32x4_t v) { - return - MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), - MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); -} - inline static float vmaxvq_f32(float32x4_t v) { return MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), diff --git a/k_quants.c b/k_quants.c index 3deeaedf7..4accd2480 100644 --- a/k_quants.c +++ b/k_quants.c @@ -13,6 +13,26 @@ // #include +#if !defined(__aarch64__) +inline static int32_t vaddvq_s16(int16x8_t v) { + return + (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + + (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + + (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + + (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { + return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} +#endif + #else #ifdef __wasm_simd128__ @@ -1302,7 +1322,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16_t m3 = vdupq_n_u8(0x3); const uint8x16_t m4 = vdupq_n_u8(0xF); +#if defined(__ARM_FEATURE_DOTPROD) const int32x4_t vzero = vdupq_n_s32(0); +#endif int8x16x2_t q2bytes; uint8_t aux[16]; @@ -1608,7 +1630,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri #ifdef __ARM_NEON const uint8x16_t m3 = vdupq_n_u8(0x3); +#if defined(__ARM_FEATURE_DOTPROD) const int32x4_t vzero = vdupq_n_s32(0); +#endif int8x16x4_t q2bytes; @@ -2592,8 +2616,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - //int32x4_t isum = mzero; - int32_t sumi1 = 0; int32_t sumi2 = 0; @@ -3092,9 +3114,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri #ifdef __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); - const int32x4_t mzero = vdupq_n_s32(0); const uint8x16_t mone = vdupq_n_u8(1); const uint8x16_t mtwo = vdupq_n_u8(2); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t mzero = vdupq_n_s32(0); +#endif int8x16x4_t q5bytes; @@ -3437,8 +3461,10 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri #ifdef __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); - const int32x4_t mzero = vdupq_n_s32(0); const uint8x16_t mh = vdupq_n_u8(16); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t mzero = vdupq_n_s32(0); +#endif int8x16x4_t q5bytes; uint8x16x4_t q5h; @@ -3656,7 +3682,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri float sum = 0; const uint8x16_t m4b = vdupq_n_u8(0xF); +#if defined(__ARM_FEATURE_DOTPROD) const int32x4_t vzero = vdupq_n_s32(0); +#endif //const int8x16_t m32s = vdupq_n_s8(32); const uint8x16_t mone = vdupq_n_u8(3); @@ -4045,8 +4073,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri float sum = 0; const uint8x16_t m4b = vdupq_n_u8(0xF); - const int32x4_t vzero = vdupq_n_s32(0); const int8x16_t m32s = vdupq_n_s8(32); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif const uint8x16_t mone = vdupq_n_u8(3); diff --git a/llama.cpp b/llama.cpp index 3114d3311..2b0cf30f6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -325,6 +325,44 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_GPT2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, + { + LLM_ARCH_GPTJ, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, + { + LLM_ARCH_GPTNEOX, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_MPT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, + { + LLM_ARCH_UNKNOWN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, }; static llm_arch llm_arch_from_string(const std::string & name) { @@ -1605,9 +1643,13 @@ static void llm_load_hparams( GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT)); - if (hparams.n_rot != hparams.n_embd / hparams.n_head) { - throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head)); + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { + if (hparams.n_rot != hparams.n_embd / hparams.n_head) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head)); + } } + // gpt-neox n_rot = rotary_pct * (n_embd / n_head) + // gpt-j n_rot = rotary_dim } // arch-specific KVs