From 0d2b66c638cf56c46c152dbc1fe1709672650d05 Mon Sep 17 00:00:00 2001 From: slaren Date: Mon, 10 Jul 2023 17:32:06 +0200 Subject: [PATCH] ggml backend interface wip refactor ggml-cuda --- .github/workflows/build.yml | 8 +- CMakeLists.txt | 10 +- Makefile | 29 +- examples/simple/simple.cpp | 229 +- ggml-backend.c | 435 ++++ ggml-backend.h | 129 + ggml-cuda-kern.h | 468 ++++ ggml-cuda-quant.h | 920 +++++++ ggml-cuda.cu | 4831 ++++++++++------------------------- ggml-cuda.h | 23 +- ggml.c | 526 ++-- ggml.h | 64 +- llama-util.h | 41 +- llama.cpp | 1542 ++++++----- llama.h | 7 +- 15 files changed, 4480 insertions(+), 4782 deletions(-) create mode 100644 ggml-backend.c create mode 100644 ggml-backend.h create mode 100644 ggml-cuda-kern.h create mode 100644 ggml-cuda-quant.h diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b6e21b4ec..aa0913f61 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -308,13 +308,13 @@ jobs: path: | llama-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip - windows-latest-cmake-cublas: + windows-latest-cmake-cuda: runs-on: windows-latest strategy: matrix: cuda: ['12.1.0', '11.7.1'] - build: ['cublas'] + build: ['cuda'] steps: - name: Clone @@ -333,7 +333,7 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_BUILD_SERVER=ON -DLLAMA_CUBLAS=ON + cmake .. -DLLAMA_BUILD_SERVER=ON -DLLAMA_CUDA=ON cmake --build . --config Release - name: Get commit hash @@ -395,7 +395,7 @@ jobs: - macOS-latest-make - macOS-latest-cmake - windows-latest-cmake - - windows-latest-cmake-cublas + - windows-latest-cmake-cuda steps: - name: Download artifacts diff --git a/CMakeLists.txt b/CMakeLists.txt index d9381dae1..1930a905a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,7 +67,7 @@ endif() option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) option(LLAMA_BLAS "llama: use BLAS" OFF) set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") -option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +option(LLAMA_CUDA "llama: use CUDA" OFF) option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") @@ -239,18 +239,18 @@ if (LLAMA_K_QUANTS) endif() endif() -if (LLAMA_CUBLAS) +if (LLAMA_CUDA) cmake_minimum_required(VERSION 3.17) find_package(CUDAToolkit) if (CUDAToolkit_FOUND) - message(STATUS "cuBLAS found") + message(STATUS "CUDA found") enable_language(CUDA) set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h) - add_compile_definitions(GGML_USE_CUBLAS) + add_compile_definitions(GGML_USE_CUDA) if (LLAMA_CUDA_FORCE_DMMV) add_compile_definitions(GGML_CUDA_FORCE_DMMV) endif() @@ -280,7 +280,7 @@ if (LLAMA_CUBLAS) message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() - message(WARNING "cuBLAS not found") + message(WARNING "CUDA not found") endif() endif() diff --git a/Makefile b/Makefile index 6c74e1346..60c8922d4 100644 --- a/Makefile +++ b/Makefile @@ -55,6 +55,12 @@ else CXXFLAGS += -DNDEBUG endif +ifdef LLAMA_SANITIZE + CFLAGS += -g -fsanitize=$(LLAMA_SANITIZE) -fno-omit-frame-pointer + CXXFLAGS += -g -fsanitize=$(LLAMA_SANITIZE) -fno-omit-frame-pointer + LDFLAGS += -g -fsanitize=$(LLAMA_SANITIZE) +endif + ifdef LLAMA_SERVER_VERBOSE CXXFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE) endif @@ -163,13 +169,16 @@ ifdef LLAMA_BLIS LDFLAGS += -lblis -L/usr/local/lib endif # 
LLAMA_BLIS -ifdef LLAMA_CUBLAS - CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include - CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include +ifdef LLAMA_CUDA + CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include + CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib OBJS += ggml-cuda.o NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler +ifdef LLAMA_DEBUG + NVCCFLAGS += -lineinfo +endif # LLAMA_DEBUG ifdef CUDA_DOCKER_ARCH NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH) else @@ -198,10 +207,9 @@ ifdef LLAMA_CUDA_KQUANTS_ITER else NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2 endif - -ggml-cuda.o: ggml-cuda.cu ggml-cuda.h +ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml-cuda-kern.h ggml-cuda-quant.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ -endif # LLAMA_CUBLAS +endif # LLAMA_CUDA ifdef LLAMA_CLBLAST CFLAGS += -DGGML_USE_CLBLAST @@ -275,6 +283,9 @@ $(info I CXXFLAGS: $(CXXFLAGS)) $(info I LDFLAGS: $(LDFLAGS)) $(info I CC: $(CCV)) $(info I CXX: $(CXXV)) +ifdef LLAMA_CUDA +$(info I NVCC: $(NVCCV)) +endif # LLAMA_CUDA $(info ) # @@ -284,6 +295,12 @@ $(info ) ggml.o: ggml.c ggml.h ggml-cuda.h $(CC) $(CFLAGS) -c $< -o $@ +# temporary, probably will be added to ggml.c +ggml-backend.o: ggml-backend.c ggml-backend.h ggml.h + $(CC) $(CFLAGS) -c $< -o $@ + +OBJS += ggml-backend.o + llama.o: llama.cpp ggml.h ggml-cuda.h ggml-metal.h llama.h llama-util.h $(CXX) $(CXXFLAGS) -c $< -o $@ diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index aa2c4352d..a4046302e 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -1,46 +1,14 @@ -#ifndef _GNU_SOURCE -#define _GNU_SOURCE -#endif - -#include "common.h" -#include "llama.h" -#include "build-info.h" - -#include -#include -#include -#include -#include -#include -#include -#include +#include #include #include -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) -#include -#include -#elif defined (_WIN32) -#define WIN32_LEAN_AND_MEAN -#define NOMINMAX -#include -#include -#endif +#include "llama.h" - -int main(int argc, char ** argv) -{ - gpt_params params; - - //--------------------------------- - // Print help : - //--------------------------------- - - if ( argc == 1 || argv[1][0] == '-' ) - { - printf( "usage: %s MODEL_PATH [PROMPT]\n" , argv[0] ); - return 1 ; +void generate_sequence(llama_context * ctx, int n_ctx, const std::vector& prompt_tokens, float temperature) { + // print the tokens from the prompt + for (llama_token id : prompt_tokens) { + printf("%s", llama_token_to_str(ctx, id)); } //--------------------------------- @@ -107,75 +75,164 @@ int main(int argc, char ** argv) fflush(stdout); + // the maximum number of tokens to generate at a time + // TODO: not supported, remove + const int CUDA_MAX_TOKENS = 1; + llama_token tokens_out[CUDA_MAX_TOKENS]; - //--------------------------------- - // Main prediction loop : - //--------------------------------- + // current position in the context window + int n_past = 0; - // The LLM keeps a contextual cache memory of previous token evaluation. 
-    // Usually, once this cache is full, it is required to recompute a compressed context based on previous
-    // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
-    // example, we will just stop the loop once this cache is full or once an end of stream is detected.
+    // number of tokens to generate
+    int n_tokens_out;
-    while ( llama_get_kv_cache_token_count( ctx ) < max_context_size )
-    {
-        //---------------------------------
-        // Evaluate the tokens :
-        //---------------------------------
+    // list of tokens to evaluate
+    // note that at most llama_context_params::n_batch tokens can be evaluated at a time
+    std::vector<llama_token> token_list = prompt_tokens;
-        if ( llama_eval( ctx , tokens_list.data() , tokens_list.size() , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) )
-        {
-            fprintf( stderr, "%s : failed to eval\n" , __func__ );
-            return 1;
+    while (n_past < n_ctx) {
+        // evaluate the tokens
+
+        // llama_eval generates one token at a time
+        n_tokens_out = 1;
+
+        // number of threads to use for CPU evaluation - ignored if compiled with CUDA support
+        const int n_threads = 4;
+        // note: llama_eval is not compatible with GPU sampling
+        if (llama_eval(ctx, token_list.data(), token_list.size(), n_past, n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__ );
+            exit(1);
         }
-        tokens_list.clear();
-
-        //---------------------------------
-        // Select the best prediction :
-        //---------------------------------
-
-        llama_token new_token_id = 0;
-
-        auto logits = llama_get_logits( ctx );
-        auto n_vocab = llama_n_vocab( ctx ); // the size of the LLM vocabulary (in tokens)
+        // perform sampling on the CPU
+        float * logits = llama_get_logits(ctx);
+        auto n_vocab = llama_n_vocab(ctx);
+        // initialize candidate array from logits
         std::vector<llama_token_data> candidates;
-        candidates.reserve( n_vocab );
-
-        for( llama_token token_id = 0 ; token_id < n_vocab ; token_id++ )
-        {
-            candidates.emplace_back( llama_token_data{ token_id , logits[ token_id ] , 0.0f } );
+        candidates.reserve(n_vocab);
+        for(llama_token token_id = 0 ; token_id < n_vocab ; token_id++) {
+            candidates.push_back(llama_token_data{ token_id, logits[token_id], 0.0f});
         }
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-        // Select it using the "Greedy sampling" method :
-        new_token_id = llama_sample_token_greedy( ctx , &candidates_p );
+        // sample token
+        llama_sample_temperature(ctx, &candidates_p, temperature);
+        tokens_out[0] = llama_sample_token(ctx, &candidates_p);
+        // increment the position in the context window
+        n_past += token_list.size() + n_tokens_out - 1;
-        // is it an end of stream ?
-        if ( new_token_id == llama_token_eos() )
-        {
-            fprintf(stderr, " [end of text]\n");
-            break;
+        token_list.clear();
+
+        // print the new tokens
+        for (int i = 0; i < n_tokens_out; i++) {
+            llama_token new_token_id = tokens_out[i];
+
+            // is it an end of stream ?
+ if (new_token_id == llama_token_eos()) { + fprintf(stderr, " [end of text]\n"); + //return; + } + + // print the new token : + printf("%s", llama_token_to_str(ctx, new_token_id)); } + fflush(stdout); - // Print the new token : - printf( "%s" , llama_token_to_str( ctx , new_token_id ) ); - fflush( stdout ); + // push the last new token for the next evaluation + token_list.push_back(tokens_out[n_tokens_out - 1]); + } +} - // Push this new token for next evaluation : - tokens_list.push_back( new_token_id ); +int main(int argc, char ** argv) { + if (argc < 2 || argv[1][0] == '-') { + printf("usage: %s [prompt]\n", argv[0]); + printf(" note: passing a temp parameter will enable GPU sampling\n"); + return 1 ; + } - } // wend of main loop + std::string model = argv[1]; + struct llama_context_params lparams = llama_context_default_params(); - llama_free( ctx ); - llama_free_model( model ); + if (argc >= 3) { + lparams.n_ctx = std::stoi(argv[2]); + } else { + lparams.n_ctx = 512; + } + + int n_gens; + if (argc >= 4) { + n_gens = std::stoi(argv[3]); + } else { + n_gens = 1; + } + + float temperature; + + if (argc >= 5) { + temperature = std::stof(argv[4]); + } else { + temperature = 0.8f; + } + + std::string prompt; + if (argc >= 6) { + prompt = argv[5]; + } else { + prompt = "Hello my name is"; + } + + // initialize llama.cpp + bool numa = false; + llama_init_backend(numa); + + llama_model * lmodel = llama_load_model_from_file(model.c_str(), lparams); + if (lmodel == NULL) { + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, model.c_str()); + return 1; + } + + llama_context * ctx = llama_new_context_with_model(lmodel, lparams); + if (ctx == NULL) { + fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, model.c_str()); + llama_free_model(lmodel); + return 1; + } + + // tokenize the prompt + std::vector token_list(lparams.n_ctx); + int prompt_tokens = llama_tokenize(ctx, prompt.c_str(), token_list.data(), token_list.size(), true); + if (prompt_tokens <= 0) { + fprintf(stderr, "%s: error: unable to tokenize prompt\n", __func__); + return 1; + } + + token_list.resize(prompt_tokens); + + const int max_context_size = llama_n_ctx(ctx); + const int max_tokens_list_size = max_context_size - 4 ; + + if ((int)token_list.size() > max_tokens_list_size) { + fprintf( stderr, "%s: error: prompt too long (%d tokens, max %d)\n" , + __func__, (int)token_list.size(), max_tokens_list_size ); + return 1; + } + + fprintf(stderr, "\n\n"); + + // generate the sequences + for (int i = 0; i < n_gens; i++) { + printf("==== GENERATION %d ====\n", i + 1); + generate_sequence(ctx, max_context_size, token_list, temperature); + printf("\n\n"); + } + + llama_print_timings(ctx); + llama_free(ctx); llama_backend_free(); return 0; } - -// EOF diff --git a/ggml-backend.c b/ggml-backend.c new file mode 100644 index 000000000..23a0c48c2 --- /dev/null +++ b/ggml-backend.c @@ -0,0 +1,435 @@ +#include "ggml-backend.h" +#include +#include +#include +#include +#include + +#define UNUSED(x) (void)(x) + +// backend buffer + +struct ggml_buffer ggml_backend_alloc_buffer(struct ggml_backend * backend, size_t size, size_t max_tensors) { + struct ggml_buffer buffer; + buffer.mem_size = ggml_tensor_overhead() * max_tensors; + buffer.mem_buffer = malloc(buffer.mem_size); + buffer.backend = backend; + // size += 128 * max_tensors; // alignment overhead + buffer.backend_buffer = backend->interface->alloc_buffer(backend->context, size); + return buffer; +} + +void ggml_backend_free_buffer(struct 
ggml_buffer * buffer) { + struct ggml_backend * backend = buffer->backend; + backend->interface->free_buffer(backend->context, buffer->backend_buffer); + free(buffer->mem_buffer); +} + +// backend copy + +static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { + if (a->type != b->type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (a->ne[i] != b->ne[i]) { + return false; + } + if (a->nb[i] != b->nb[i]) { + return false; + } + } + return true; +} + +void ggml_backend_cpy_tensor(struct ggml_tensor * dst, struct ggml_tensor * src) { + //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]); + //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]); + GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); + + // printf("cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src)); + + if (src == dst) { + return; + } + + if (dst->backend->interface->cpy_tensor_from != NULL) { + dst->backend->interface->cpy_tensor_from(dst->backend->context, src, dst); + } else if (src->backend->interface->cpy_tensor_to != NULL) { + src->backend->interface->cpy_tensor_to(src->backend->context, src, dst); + } else { + // not ideal, but shouldn't be hit when copying from/to CPU + // TODO: print a performance warning in debug builds + size_t nbytes = ggml_nbytes(src); + void * data = malloc(nbytes); + ggml_backend_get_tensor(src, data, 0, nbytes); + ggml_backend_set_tensor(dst, data, 0, nbytes); + free(data); + } +} + +// backend CPU + +struct ggml_backend_cpu_context { + int n_threads; + void * work_data; + size_t work_size; +}; + +static const char * ggml_backend_cpu_name(ggml_backend_context_t ctx) { + return "CPU"; + + UNUSED(ctx); +} + +static void ggml_backend_cpu_free_context(ggml_backend_context_t ctx) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)ctx; + free(cpu_ctx->work_data); + free(ctx); +} + +struct cpu_backend_buffer { + void * data; + size_t offset; + size_t size; +}; + +static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512 + +static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) { + assert(alignment && !(alignment & (alignment - 1))); // power of 2 + size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment; + return offset + align; +} + +static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_context_t ctx, size_t size) { + struct cpu_backend_buffer * buffer = malloc(sizeof(struct cpu_backend_buffer)); + buffer->data = malloc(size); + buffer->offset = aligned_offset(buffer->data, 0, TENSOR_ALIGNMENT); + buffer->size = size; + return buffer; + + UNUSED(ctx); +} + +static void ggml_backend_cpu_free_buffer(ggml_backend_context_t ctx, ggml_backend_buffer_t buffer) { + struct cpu_backend_buffer * cpu_buffer = (struct cpu_backend_buffer *)buffer; + free(cpu_buffer->data); + free(cpu_buffer); + + UNUSED(ctx); +} + +static void ggml_backend_cpu_reset_buffer(ggml_backend_context_t ctx, ggml_backend_buffer_t buffer) { + struct cpu_backend_buffer * cpu_buffer = (struct cpu_backend_buffer *)buffer; + 
cpu_buffer->offset = aligned_offset(cpu_buffer->data, 0, TENSOR_ALIGNMENT); + + UNUSED(ctx); +} + +static void ggml_backend_cpu_alloc_tensor(ggml_backend_context_t ctx, ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + struct cpu_backend_buffer * cpu_buffer = (struct cpu_backend_buffer *)buffer; + + // TODO: make this error recoverable + if (cpu_buffer->offset + ggml_nbytes(tensor) > cpu_buffer->size) { + fprintf(stderr, "%s: not enough space in the buffer (needed %zu, available %zu)\n", + __func__, ggml_nbytes(tensor), cpu_buffer->size - cpu_buffer->offset); + GGML_ASSERT(false); + } + + tensor->data = (char*)cpu_buffer->data + cpu_buffer->offset; + cpu_buffer->offset = aligned_offset(cpu_buffer->data, cpu_buffer->offset + ggml_nbytes(tensor), TENSOR_ALIGNMENT); + + UNUSED(ctx); +} + +static void ggml_backend_cpu_set_tensor_async(ggml_backend_context_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy((char *)tensor->data + offset, data, size); + + UNUSED(ctx); +} + +static void ggml_backend_cpu_get_tensor_async(ggml_backend_context_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy(data, (const char *)tensor->data + offset, size); + + UNUSED(ctx); +} + +static void ggml_backend_cpu_synchronize(ggml_backend_context_t ctx) { + UNUSED(ctx); +} + +static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_context_t ctx, struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_backend_get_tensor(src, dst->data, 0, ggml_nbytes(src)); + + UNUSED(ctx); +} + +static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_context_t ctx, struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_backend_set_tensor(dst, src->data, 0, ggml_nbytes(src)); + + UNUSED(ctx); +} + +struct ggml_backend_cpu_plan { + struct ggml_cplan cplan; + struct ggml_cgraph cgraph; +}; + +static ggml_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_context_t ctx, struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)ctx; + + struct ggml_backend_cpu_plan * cpu_plan = malloc(sizeof(struct ggml_backend_cpu_plan)); + + cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + cpu_plan->cgraph = *cgraph; + + if (cpu_plan->cplan.work_size > 0) { + cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); + } + + return cpu_plan; +} + +static void ggml_backend_cpu_graph_plan_free(ggml_backend_context_t ctx, ggml_graph_plan_t plan) { + struct ggml_backend_cpu_plan * cpu_plan = (struct ggml_backend_cpu_plan *)plan; + + free(cpu_plan->cplan.work_data); + free(cpu_plan); + + UNUSED(ctx); +} + +static void ggml_backend_cpu_graph_plan_compute(ggml_backend_context_t ctx, ggml_graph_plan_t plan) { + struct ggml_backend_cpu_plan * cpu_plan = (struct ggml_backend_cpu_plan *)plan; + + ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); + + UNUSED(ctx); +} + +static void ggml_backend_cpu_graph_compute(ggml_backend_context_t ctx, struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)ctx; + + struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + + if (cpu_ctx->work_size < cplan.work_size) { + 
// TODO: may be faster to free and use malloc to avoid the copy + cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size); + cpu_ctx->work_size = cplan.work_size; + } + + cplan.work_data = cpu_ctx->work_data; + + ggml_graph_compute(cgraph, &cplan); +} + +static struct ggml_backend_interface cpu_backend_interface = { + /* .get_name = */ ggml_backend_cpu_name, + /* .free_context = */ ggml_backend_cpu_free_context, + /* .alloc_buffer = */ ggml_backend_cpu_alloc_buffer, + /* .free_buffer = */ ggml_backend_cpu_free_buffer, + /* .reset_buffer = */ ggml_backend_cpu_reset_buffer, + /* .alloc_tensor = */ ggml_backend_cpu_alloc_tensor, + /* .set_tensor_async = */ ggml_backend_cpu_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_cpu_get_tensor_async, + /* .synchronize = */ ggml_backend_cpu_synchronize, + /* .cpy_tensor_from = */ ggml_backend_cpu_cpy_tensor_from, + /* .cpy_tensor_to = */ ggml_backend_cpu_cpy_tensor_to, + /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, + /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free, + /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, + /* .graph_compute = */ ggml_backend_cpu_graph_compute +}; + +struct ggml_backend ggml_backend_cpu_init(void) { + struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); + ctx->n_threads = GGML_DEFAULT_N_THREADS; + ctx->work_data = NULL; + ctx->work_size = 0; + + struct ggml_backend cpu_backend = { + /* .interface = */ &cpu_backend_interface, + /* .context = */ ctx + }; + return cpu_backend; +} + +void ggml_backend_cpu_set_n_threads(struct ggml_backend * backend_cpu, int n_threads) { + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->n_threads = n_threads; +} + +// splits + +struct ggml_graph_splits ggml_graph_split_init(void) { + struct ggml_graph_splits splits = {0}; + return splits; +} + +// TODO: this can be removed after allocating the graphs in a ggml_context +void ggml_graph_splits_free(struct ggml_graph_splits * splits) { + for (int i = 0; i < splits->n_splits; i++) { + if (splits->splits[i].graph) { + free(splits->splits[i].graph); + } + } +} + +void ggml_graph_splits_add_n_va(struct ggml_graph_splits * splits, struct ggml_tensor *** inputs, struct ggml_context * ctx, const char * fmt, va_list args) { + GGML_ASSERT(splits->n_splits < GGML_MAX_SPLITS); + + struct ggml_graph_split * split = &splits->splits[splits->n_splits]; + + if ((*inputs[0])->backend == ggml_get_ctx_backend(ctx)) { + if (splits->n_splits > 0) { + char name[GGML_MAX_NAME - 1]; // silence -Wformat-truncation + vsnprintf(name, sizeof(name), fmt, args); + char new_name[GGML_MAX_NAME]; + snprintf(new_name, sizeof(new_name), "%s,%s", splits->splits[splits->n_splits - 1].name, name); + strcpy(splits->splits[splits->n_splits - 1].name, new_name); + return; + } + // always add the first split + int i = 0; + while (inputs[i] != NULL) { + GGML_ASSERT(i < GGML_MAX_SPLIT_INPUTS); + split->src_inputs[i] = *inputs[i]; + split->dst_inputs[i] = *inputs[i]; + i++; + } + split->src_inputs[i] = NULL; + split->dst_inputs[i] = NULL; + } else { + int i = 0; + while (inputs[i] != NULL) { + GGML_ASSERT(i < GGML_MAX_SPLIT_INPUTS); + split->src_inputs[i] = *inputs[i]; + split->dst_inputs[i] = ggml_dup_tensor(ctx, *inputs[i]); + // TODO: maybe support different layings in ggml_backend_cpy_tensor instead + for (int j = 0; j < GGML_MAX_DIMS; j++) { + split->dst_inputs[i]->nb[j] = split->src_inputs[i]->nb[j]; + } + 
ggml_set_name(split->dst_inputs[i], ggml_get_name(*inputs[i])); + *inputs[i] = split->dst_inputs[i]; + i++; + } + split->src_inputs[i] = NULL; + split->dst_inputs[i] = NULL; + } + + vsnprintf(split->name, GGML_MAX_NAME, fmt, args); + split->graph = NULL; + splits->n_splits++; +} + +void ggml_graph_splits_add_n(struct ggml_graph_splits * splits, struct ggml_tensor *** input, struct ggml_context * ctx, const char * fmt, ...) { + va_list args; + va_start(args, fmt); + ggml_graph_splits_add_n_va(splits, input, ctx, fmt, args); + va_end(args); +} + +void ggml_graph_splits_add(struct ggml_graph_splits * splits, struct ggml_tensor ** input, struct ggml_context * ctx, const char * fmt, ...) { + va_list args; + va_start(args, fmt); + ggml_graph_splits_add_n_va(splits, (struct ggml_tensor**[2]){ input, NULL }, ctx, fmt, args); + va_end(args); +} + +void ggml_graph_splits_build_forward(struct ggml_graph_splits * splits, struct ggml_tensor * output) { + struct ggml_tensor *last_outputs[2] = { output, NULL }; + struct ggml_tensor ** outputs; + + for (int i = 0; i < splits->n_splits; i++) { + struct ggml_graph_split * split = &splits->splits[i]; + + if (i < splits->n_splits - 1) { + outputs = splits->splits[i + 1].src_inputs; + } else { + outputs = last_outputs; + } + + // build the graph + // TODO: allocate graphs in context + split->graph = (struct ggml_cgraph *) malloc(sizeof(struct ggml_cgraph)); + memset(split->graph, 0, sizeof(struct ggml_cgraph)); + // *split->graph = ggml_build_forward_range(output, split->input); + // *split->graph = ggml_build_forward(output); + for (int j = 0; outputs[j] != NULL; j++) { + ggml_build_forward_expand(split->graph, outputs[j]); + } + + for (int j = 1; j < split->graph->n_nodes; j++) { + if (split->graph->nodes[j]->backend != split->graph->nodes[0]->backend) { + fprintf(stderr, "split %s: node %s has different backend (%s) than the first node (%s)\n", + split->name, split->graph->nodes[j]->name, + ggml_backend_name(split->graph->nodes[j]->backend), + ggml_backend_name(split->graph->nodes[0]->backend)); + } + } + for (int j = 1; j < split->graph->n_leafs; j++) { + if (split->graph->leafs[j]->backend != split->graph->leafs[0]->backend) { + fprintf(stderr, "split %s: leaf %s has different backend (%s) than the first leaf (%s)\n", + split->name, split->graph->leafs[j]->name, + ggml_backend_name(split->graph->leafs[j]->backend), + ggml_backend_name(split->graph->leafs[0]->backend)); + } + } + } + + // close graphs + for (int i = 0; i < splits->n_splits; i++) { + struct ggml_graph_split * split = &splits->splits[i]; + ggml_graph_close(split->graph); + } +} + +void ggml_graph_splits_compute(struct ggml_graph_splits * splits) { + uint64_t copy_us = 0; + uint64_t compute_cpu_us = 0; + uint64_t compute_gpu_us = 0; + int n_nodes = 0; + for (int i = 0; i < splits->n_splits; i++) { + struct ggml_graph_split * split = &splits->splits[i]; + + //printf("computing split %i (%s) on backend %s (%i nodes)\n", i, split->name, ggml_backend_name(split->dst_inputs[0]->backend), split->graph->n_nodes); + + // copy the input tensor to the backend + uint64_t copy_start_us = ggml_time_us(); + for (int j = 0; split->src_inputs[j] != NULL; j++) { + if (split->src_inputs[j] != split->dst_inputs[j]) { + //printf("\tcopying tensor %d (%s) (%lu bytes)\n", j, split->src_inputs[j]->name, ggml_nbytes(split->src_inputs[j])); + ggml_backend_cpy_tensor(split->dst_inputs[j], split->src_inputs[j]); + } + } + ggml_backend_synchronize(split->dst_inputs[0]->backend); + copy_us += ggml_time_us() - 
copy_start_us; + +#if 0 + char split_filename[GGML_MAX_NAME]; + snprintf(split_filename, GGML_MAX_NAME, "split_%i.dot", i); + ggml_graph_dump_dot(split->graph, NULL, split_filename); +#endif + uint64_t start = ggml_time_us(); + ggml_backend_graph_compute(split->dst_inputs[0]->backend, split->graph); + ggml_backend_synchronize(split->dst_inputs[0]->backend); + uint64_t end = ggml_time_us(); + if (strcmp(ggml_backend_name(split->dst_inputs[0]->backend), "CPU") == 0) { + compute_cpu_us += end - start; + } else { + compute_gpu_us += end - start; + } + + n_nodes += split->graph->n_nodes; + } + + //printf("splits: %d, nodes: %d, copy: %.2fms, compute_cpu: %.2fms, compute_gpu: %.2fms\n", splits->n_splits, n_nodes, copy_us / 1000.0, compute_cpu_us / 1000.0, compute_gpu_us / 1000.0); + //exit(0); +} diff --git a/ggml-backend.h b/ggml-backend.h new file mode 100644 index 000000000..ce5aac2b5 --- /dev/null +++ b/ggml-backend.h @@ -0,0 +1,129 @@ +#pragma once + +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + + typedef void * ggml_graph_plan_t; + typedef void * ggml_backend_context_t; + typedef void * ggml_backend_buffer_t; + struct ggml_backend; + + // buffers have space for the tensor structs in host memory, and tensor data in backend-specific memory + struct ggml_buffer { + // host memory + size_t mem_size; + void * mem_buffer; + + // tensor data + struct ggml_backend * backend; + ggml_backend_buffer_t backend_buffer; // backend-specific data + }; + + struct ggml_backend_interface { + const char * (*get_name)(ggml_backend_context_t ctx); + + void (*free_context)(ggml_backend_context_t ctx); + + // buffers + ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_context_t ctx, size_t size); + void (*free_buffer) (ggml_backend_context_t ctx, ggml_backend_buffer_t buffer); + void (*reset_buffer)(ggml_backend_context_t ctx, ggml_backend_buffer_t buffer); + void (*alloc_tensor)(ggml_backend_context_t ctx, ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + + // TODO: pinned buffers for faster transfers between host and device + + // tensor data access + // these functions can be asynchronous. 
helper functions are provided for synchronous access that automatically call synchronize + void (*set_tensor_async)(ggml_backend_context_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async)(ggml_backend_context_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*synchronize)(ggml_backend_context_t ctx); + + // (optional) copy tensor between different backends, allow for single-copy tranfers + void (*cpy_tensor_from)(ggml_backend_context_t ctx, struct ggml_tensor * src, struct ggml_tensor * dst); + void (*cpy_tensor_to) (ggml_backend_context_t ctx, struct ggml_tensor * src, struct ggml_tensor * dst); + + + // compute graph with a plan + ggml_graph_plan_t (*graph_plan_create) (ggml_backend_context_t ctx, struct ggml_cgraph * cgraph); + void (*graph_plan_free) (ggml_backend_context_t ctx, ggml_graph_plan_t plan); + void (*graph_plan_compute)(ggml_backend_context_t ctx, ggml_graph_plan_t plan); + + // compute graph without a plan + void (*graph_compute) (ggml_backend_context_t ctx, struct ggml_cgraph * cgraph); + + // check if a backend supports a given operation + // this could be used to fallback automatically to the CPU backend if a backend doesn't support an operation + // bool (*supports_op)(ggml_backend_context_t ctx, struct ggml_tensor * op); + }; + + struct ggml_backend { + struct ggml_backend_interface * interface; + ggml_backend_context_t context; + }; + + // backend helper functions + static inline const char * ggml_backend_name(struct ggml_backend * backend) { return backend->interface->get_name(backend->context); } + static inline void ggml_backend_free_context(struct ggml_backend * backend) { backend->interface->free_context(backend->context); } + static inline void ggml_backend_set_tensor_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { tensor->backend->interface->set_tensor_async(tensor->backend->context, tensor, data, offset, size); } + static inline void ggml_backend_get_tensor_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { tensor->backend->interface->get_tensor_async(tensor->backend->context, tensor, data, offset, size); } + static inline void ggml_backend_set_tensor(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { tensor->backend->interface->set_tensor_async(tensor->backend->context, tensor, data, offset, size); tensor->backend->interface->synchronize(tensor->backend->context); } + static inline void ggml_backend_get_tensor(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { tensor->backend->interface->get_tensor_async(tensor->backend->context, tensor, data, offset, size); tensor->backend->interface->synchronize(tensor->backend->context); } + static inline void ggml_backend_synchronize(struct ggml_backend * backend) { backend->interface->synchronize(backend->context); } + static inline ggml_graph_plan_t ggml_backend_graph_plan_create(struct ggml_backend * backend, struct ggml_cgraph * cgraph) { return backend->interface->graph_plan_create(backend->context, cgraph); } + static inline void ggml_backend_graph_plan_free(struct ggml_backend * backend, ggml_graph_plan_t plan) { backend->interface->graph_plan_free(backend->context, plan); } + static inline void ggml_backend_graph_plan_compute(struct ggml_backend * backend, ggml_graph_plan_t plan) { backend->interface->graph_plan_compute(backend->context, plan); } + static inline void 
ggml_backend_graph_compute(struct ggml_backend * backend, struct ggml_cgraph * cgraph) { backend->interface->graph_compute(backend->context, cgraph); } + + // buffer and tensor allocation + GGML_API struct ggml_buffer ggml_backend_alloc_buffer(struct ggml_backend * backend, size_t size, size_t max_tensors); + GGML_API void ggml_backend_free_buffer(struct ggml_buffer * buffer); + static inline void ggml_backend_reset_buffer(struct ggml_buffer * buffer) { buffer->backend->interface->reset_buffer(buffer->backend->context, buffer->backend_buffer); } + static inline void ggml_backend_alloc_tensor(struct ggml_buffer * buffer, struct ggml_tensor * tensor) { buffer->backend->interface->alloc_tensor(buffer->backend->context, buffer->backend_buffer, tensor); } + + // tensor copy between different backends + GGML_API void ggml_backend_cpy_tensor(struct ggml_tensor * dst, struct ggml_tensor * src); + + // CPU backend + GGML_API struct ggml_backend ggml_backend_cpu_init(void); + GGML_API void ggml_backend_cpu_set_n_threads(struct ggml_backend * backend_cpu, int n_threads); + + /////////////////////////// + + // graph splitting + #define GGML_MAX_SPLITS 200 + #define GGML_MAX_SPLIT_INPUTS 4 + + struct ggml_graph_split { + char name[GGML_MAX_NAME]; + struct ggml_tensor * src_inputs[GGML_MAX_SPLIT_INPUTS + 1]; + struct ggml_tensor * dst_inputs[GGML_MAX_SPLIT_INPUTS + 1]; + struct ggml_cgraph * graph; + }; + + // TODO: this shouldn't be fixed size, allocate from ggml_context + struct ggml_graph_splits { + int n_splits; + struct ggml_graph_split splits[GGML_MAX_SPLITS]; + }; + + // TODO: allocate in ggml_context + struct ggml_graph_splits ggml_graph_split_init(void); + // this won't be needed once we can allocate graphs from a ggml_context + GGML_API void ggml_graph_splits_free(struct ggml_graph_splits * splits); + + // add a split to the graph - single and multiple inputs versions + GGML_API void ggml_graph_splits_add(struct ggml_graph_splits * splits, struct ggml_tensor ** input, struct ggml_context * ctx, const char * fmt, ...); + GGML_API void ggml_graph_splits_add_n(struct ggml_graph_splits * splits, struct ggml_tensor *** inputs, struct ggml_context * ctx, const char * fmt, ...); + + // build graphs for all splits + GGML_API void ggml_graph_splits_build_forward(struct ggml_graph_splits * splits, struct ggml_tensor * output); + + // compute + GGML_API void ggml_graph_splits_compute(struct ggml_graph_splits * splits); + +#ifdef __cplusplus +} +#endif diff --git a/ggml-cuda-kern.h b/ggml-cuda-kern.h new file mode 100644 index 000000000..7b279f02c --- /dev/null +++ b/ggml-cuda-kern.h @@ -0,0 +1,468 @@ +// kernels for ggml-cuda +#include +#include + + +template +using to_t_cuda_t = void (*)(const void * x, dst_t * y, int k, cudaStream_t stream); + +// support for vector types in generic code +template struct vec2_t_impl; +template<> struct vec2_t_impl { typedef half2 type; }; +template<> struct vec2_t_impl { typedef float2 type; }; + +template using vec2_t = typename vec2_t_impl::type; + +template inline __host__ __device__ vec2_t make_vec2_t(const T & x, const T & y); +template<> inline __host__ __device__ vec2_t make_vec2_t(const half & x, const half & y) { return __halves2half2(x, y); } +template<> inline __host__ __device__ vec2_t make_vec2_t(const float & x, const float & y) { return make_float2(x, y); } + +// the cuda headers define operators for half2, but not for float2 +// they are defined here to simplify generic code +inline __host__ __device__ float2 operator+(const float2 & a, const float2 & 
b) { return make_float2(a.x + b.x, a.y + b.y); } +inline __host__ __device__ float2 operator-(const float2 & a, const float2 & b) { return make_float2(a.x - b.x, a.y - b.y); } +inline __host__ __device__ float2 operator*(const float2 & a, const float2 & b) { return make_float2(a.x * b.x, a.y * b.y); } +inline __host__ __device__ float2 operator/(const float2 & a, const float2 & b) { return make_float2(a.x / b.x, a.y / b.y); } +inline __host__ __device__ float2 & operator+=( float2 & a, const float2 & b) { a.x += b.x; a.y += b.y; return a; } +inline __host__ __device__ float2 & operator-=( float2 & a, const float2 & b) { a.x -= b.x; a.y -= b.y; return a; } +inline __host__ __device__ float2 & operator*=( float2 & a, const float2 & b) { a.x *= b.x; a.y *= b.y; return a; } +inline __host__ __device__ float2 & operator/=( float2 & a, const float2 & b) { a.x /= b.x; a.y /= b.y; return a; } + +template +using dequantize_kernel_t = void (*)(const void * vx, const int ib, const int iqs, vec2_t & v); + +__device__ half sqrt(const half x) { return hsqrt(x); } +__device__ half exp(const half x) { return hexp(x); } +__device__ half2 exp(const half2 x) { return h2exp(x); } +__device__ half cos(const half x) { return hcos(x); } +__device__ half sin(const half x) { return hsin(x); } +__device__ half max(const half x, const half y) { return __hmax(x, y); } +__device__ half2 max(const half2 x, const half2 y) { return __hmax2(x, y); } + + +template struct op_max { __device__ T operator()(T a, T b) const { return max(a, b); } }; +template struct op_sum { __device__ T operator()(T a, T b) const { return a + b; } }; + +template class op_t, typename T> +static inline __device__ T warp_reduce_all(T val) { + op_t op; +#pragma unroll + for (int mask = warpSize/2; mask > 0; mask /= 2) { + val = op(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); + } + return val; +} + +template +static __device__ T zero_init() { return T(0); } +template<> +__device__ half2 zero_init() { return half2(0.0f, 0.0f); } + +template class op_t, typename T> +static __device__ T block_reduce_all(const T val, const T init = zero_init()) { + const int warp_id = threadIdx.x / warpSize; // warp id within the block + const int lane_id = threadIdx.x % warpSize; // lane id within the warp + const int num_warps = blockDim.x / warpSize; // number of warps in the block + + __shared__ T lane_result[32]; // max 32 warps per block + + // reduce warps + T warp_reduction = warp_reduce_all(val); + + __syncthreads(); + + // first thread within a warp writes reduction to shared memory + if (lane_id == 0) { + lane_result[warp_id] = warp_reduction; + } + + // wait for all warps to finish writing their reductions + __syncthreads(); + + // reduce the results of all warps + T block_reduction = init; + if (lane_id < num_warps) { + block_reduction = lane_result[lane_id]; + } + + block_reduction = warp_reduce_all(block_reduction); + + return block_reduction; +} + +template +static __device__ void convert_fp16(const void * vx, const int ib, const int iqs, vec2_t & v) { + const half * x = (const half *) vx; + + v.x = (dst_t)(x[ib + iqs + 0]); + v.y = (dst_t)(x[ib + iqs + 1]); +} + +template +static __device__ void convert_fp32(const void * vx, const int ib, const int iqs, vec2_t & v) { + const float * x = (const float *) vx; + + v.x = (dst_t)(x[ib + iqs + 0]); + v.y = (dst_t)(x[ib + iqs + 1]); +} + +template +static __global__ void k_mul_mat_p021(const src0_t * vx, const src1_t * y, dst_t * dst, const int ncols_x, const int nrows_x, const int nchannels_x) { + const 
src0_t * x = vx; + // const int col_x = blockDim.x*blockIdx.x + threadIdx.x; + // const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + + const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; + + const int nrows_y = ncols_x; + const int nrows_dst = nrows_x; + const int row_dst = row_x; + + dst_t tmp = 0; + + for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { + const int col_x = col_x0 + threadIdx.x; + + if (col_x >= ncols_x) { + break; + } + + // x is transposed and permuted + const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x; + const dst_t xi = (dst_t)(x[ix]); + + const int row_y = col_x; + + // y is not transposed but permuted + const int iy = channel*nrows_y + row_y; + + tmp += xi * y[iy]; + } + + // dst is not transposed and not permuted + const int idst = channel*nrows_dst + row_dst; + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[idst] = tmp; + } +} + +template +static __global__ void k_mul_mat_vec_nc( + const src0_t * vx, const src1_t * y, dst_t * dst, const int ncols_x, const int nrows_x, + const int row_stride_x, const int nchannels_x, const int channel_stride_x) { + + const src0_t * x = vx; + + const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; + + const int nrows_y = ncols_x; + const int nrows_dst = nrows_x; + const int row_dst = row_x; + + const int idst = channel*nrows_dst + row_dst; + + dst_t tmp = 0; + + for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { + const int col_x = col_x0 + threadIdx.x; + + if (col_x >= ncols_x) { + break; + } + + const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x; + const dst_t xi = (dst_t)(x[ix]); + + const int row_y = col_x; + + const int iy = channel*nrows_y + row_y; + + tmp += xi * y[iy]; + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[idst] = tmp; + } +} + +template +static __global__ void k_cpy(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + const int i02 = i / (ne00*ne01); + const int i01 = (i - i02*ne01*ne00) / ne00; + const int i00 = i - i02*ne01*ne00 - i01*ne00; + const int x_offset = i00*nb00 + i01*nb01 + i02*nb02; + + const int i12 = i / (ne10*ne11); + const int i11 = (i - i12*ne10*ne11) / ne10; + const int i10 = i - i12*ne10*ne11 - i11*ne10; + const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; + + *(dst_t *)(cdst + dst_offset) = *(const src_t *)(cx + x_offset); +} + +template +static __global__ void k_add(const src0_t * x, const src1_t * y, dst_t * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = (dst_t)x[i] + (dst_t)y[i]; +} + +template +static __global__ void k_mul(const src0_t * x, const src1_t * y, dst_t * dst, const int kx, const int ky) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= kx) { + return; + } + dst[i] = (dst_t)x[i] * (dst_t)y[i%ky]; +} + +template +static __global__ void k_silu(const src0_t * x, 
dst_t * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = x[i] / (src0_t(1) + exp(-x[i])); +} + +// TODO: unstable with f16 compute, using f32 compute for now +template +static __global__ void k_rms_norm(const src0_t * x, dst_t * dst, const int ncols) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + const float eps = 1e-6; + + float tmp = 0; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += WARP_SIZE) { + const float xi = x[row*ncols + col]; + tmp += xi * xi; + } + + // sum up partial sums +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + const float mean = tmp / (float)ncols; + const float scale = 1.0f / sqrtf(mean + eps); + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dst[row*ncols + col] = scale * (float)x[row*ncols + col]; + } +} + +template +static __global__ void k_rope(const src0_t * x, dst_t * dst, const int ncols, const float p, const float theta_scale) { + const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x); + + if (col >= ncols) { + return; + } + + const int row = blockDim.y*blockIdx.y + threadIdx.y; + const int i = row*ncols + col; + + const dst_t theta = p * powf(theta_scale, col/2); + const dst_t sin_theta = sin(theta); + const dst_t cos_theta = cos(theta); + + const dst_t x0 = x[i + 0]; + const dst_t x1 = x[i + 1]; + + dst[i + 0] = (dst_t)x0*cos_theta - (dst_t)x1*sin_theta; + dst[i + 1] = (dst_t)x0*sin_theta + (dst_t)x1*cos_theta; +} + +template +static __global__ void k_diag_mask_inf(const src0_t * x, dst_t * dst, const int ncols, const int rows_per_channel, const int n_past) { + const int col = blockDim.x*blockIdx.x + threadIdx.x; + const int row = blockDim.y*blockIdx.y + threadIdx.y; + + if (col >= ncols) { + return; + } + + const int i = row*ncols + col; + //dst[i] = col > (n_past + row % rows_per_channel) ? 
(dst_t)-INFINITY : (dst_t)x[i]; + dst[i] = (dst_t)x[i] - (dst_t)((col > n_past + row % rows_per_channel) * INT_MAX); // equivalent within rounding error but slightly faster on GPU +} + +// TODO: numerically stable version - low prio since the softmax is computed in the fused attention kernel +// check: https://arxiv.org/pdf/2001.04438.pdf +template +static __global__ void k_soft_max_orig(const src0_t * x, dst_t * dst, const int ncols) { + const int row = blockDim.y*blockIdx.y + threadIdx.y; + const int block_size = blockDim.x; + const int tid = threadIdx.x; + + float tmp = 0; + + for (int block_start = 0; block_start < ncols; block_start += block_size) { + const int col = block_start + tid; + + if (col >= ncols) { + break; + } + + const int i = row*ncols + col; + const float val = expf(x[i]); + tmp += val; + dst[i] = val; + } + + // sum up partial sums +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + for (int block_start = 0; block_start < ncols; block_start += block_size) { + const int col = block_start + tid; + + if (col >= ncols) { + break; + } + + const int i = row*ncols + col; + dst[i] /= tmp; + } +} + +template +static __global__ void k_soft_max(const src_t * x, dst_t * dst, const int64_t nrows, const int64_t ncols) { + //assert(ncols % pack_size == 0); + const int tid = threadIdx.x; + const int num_packs = ncols / pack_size; + + for (int row = blockIdx.x; row < nrows; row += gridDim.x) { + src_t th_max = -INFINITY; + // row max thread + #pragma unroll + for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { + // load pack + src_t pack[pack_size]; + #pragma unroll + for (int i = 0; i < pack_size; i++) { + pack[i] = x[row * ncols + pack_id * pack_size + i]; + } + // reduce max pack + #pragma unroll + for (int i = 0; i < pack_size; ++i) { + th_max = max(th_max, pack[i]); + } + } + // reduce max row warp threads + src_t row_max = block_reduce_all(th_max, (src_t)-INFINITY); + + // row exp sum thread + src_t th_sum = 0; + #pragma unroll + for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { + // load pack + src_t pack[pack_size]; + #pragma unroll + for (int i = 0; i < pack_size; i++) { + pack[i] = x[row * ncols + pack_id * pack_size + i]; + } + // reduce pack + #pragma unroll + for (int i = 0; i < pack_size; ++i) { + th_sum += exp(pack[i] - row_max); + } + } + + // reduce row exp sum all threads + src_t row_sum = block_reduce_all(th_sum); + + // store (row - row_max) / row exp sum + #pragma unroll + for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { + // load pack + src_t pack[pack_size]; + #pragma unroll + for (int i = 0; i < pack_size; i++) { + pack[i] = x[row * ncols + pack_id * pack_size + i]; + } + // reduce pack + #pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack[i] = exp(pack[i] - row_max) / row_sum; + } + + // store pack + #pragma unroll + for (int i = 0; i < pack_size; i++) { + dst[row * ncols + pack_id * pack_size + i] = pack[i]; + } + } + } +} + +template +static __global__ void k_scale(const src0_t * x, dst_t * dst, const src1_t * scale, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = (dst_t)(*scale) * (dst_t)x[i]; +} + +template dequantize_kernel> +static __global__ void k_get_rows(const void * x, const int * y, dst_t * dst, const int ncols) { + const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2; + const int row = blockDim.y*blockIdx.y + threadIdx.y; + + if (col >= ncols) 
{ + return; + } + + const int r = y[row]; + + // copy x[r*ncols + col] to dst[row*ncols + col] + const int xi = r*ncols + col; + const int di = row*ncols + col; + + const int ib = xi/qk; // block index + const int iqs = (xi%qk)/qr; // quant index + const int iybs = di - di%qk; // y block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + vec2_t v; + dequantize_kernel(x, ib, iqs, v); + dst[iybs + iqs + 0] = v.x; + dst[iybs + iqs + y_offset] = v.y; +} diff --git a/ggml-cuda-quant.h b/ggml-cuda-quant.h new file mode 100644 index 000000000..1afcef04a --- /dev/null +++ b/ggml-cuda-quant.h @@ -0,0 +1,920 @@ +// quants kernels for ggml-cuda + +// QK = number of values after dequantization +// QR = QK / number of values before dequantization +// QI = number of 32 bit integers before dequantization + +#define QK4_0 32 +#define QR4_0 2 +#define QI4_0 4 +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +#define QR4_1 2 +#define QI4_1 4 +typedef struct { + half d; // delta + half m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +#define QR5_0 2 +#define QI5_0 4 +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +#define QR5_1 2 +#define QI5_1 4 +typedef struct { + half d; // delta + half m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +#define QR8_0 1 +#define QI8_0 8 +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +#define QR8_1 1 +#define QI8_1 8 +typedef struct { + half d; // delta + half s; // unquantized sum + int8_t qs[QK8_0]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); + +//================================= k-quants + +#define QK_K 256 + +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +typedef struct { + uint8_t hmask[QK_K/8]; + uint8_t qs[QK_K/4]; // nibbles / quants + uint8_t scales[3*QK_K/64]; + half d; +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding"); + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); + 
+typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); + +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales + half d; // delta +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); + + +template +using dot_kernel_k_t = void (*)(const void * vx, const int ib, const int iqs, const src1_t * y, dst_t & v); + +template +using vec_dot_q_cuda_t = dst_t (*)(const void * vbq, const block_q8_1 * bq8_1, const int iqs); + + +// TODO: f16 +template +static __global__ void quantize_q8_1(const src_t * x, void * vy, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + block_q8_1 * y = (block_q8_1 *) vy; + + const int ib = i / QK8_0; // block index + const int iqs = i % QK8_0; // quant index + + const float xi = x[i]; + float amax = fabsf(xi); + float sum = xi; + +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); + sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); + } + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs > 0) { + return; + } + + y[ib].d = d; + y[ib].s = sum; +} + +template +static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, vec2_t & v){ + const block_q4_0 * x = (const block_q4_0 *) vx; + + const dst_t d = x[ib].d; + + const uint8_t vui = x[ib].qs[iqs]; + + v.x = vui & 0xF; + v.y = vui >> 4; + + const vec2_t off2 = make_vec2_t(8, 8); + const vec2_t d2 = make_vec2_t(d, d); + + v = (v - off2) * d2; +} + +template +static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, vec2_t & v){ + const block_q4_1 * x = (const block_q4_1 *) vx; + + const dst_t d = x[ib].d; + const dst_t m = x[ib].m; + + const uint8_t vui = x[ib].qs[iqs]; + + v.x = vui & 0xF; + v.y = vui >> 4; + + const vec2_t d2 = make_vec2_t(d, d); + const vec2_t m2 = make_vec2_t(m, m); + + v = v * d2 + m2; +} + +template +static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, vec2_t & v){ + const block_q5_0 * x = (const block_q5_0 *) vx; + + const dst_t d = x[ib].d; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); + + const vec2_t off2 = make_vec2_t(16, 16); + const vec2_t d2 = make_vec2_t(d, d); + + v = (v - off2) * d2; +} + +template +static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, vec2_t & v){ + const block_q5_1 * x = (const block_q5_1 *) vx; + + const dst_t d = x[ib].d; + const dst_t m = x[ib].m; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); + + const vec2_t d2 = make_vec2_t(d, d); + const vec2_t m2 = make_vec2_t(m, 
m); + + v = v * d2 + m2; +} + +template +static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, vec2_t & v){ + const block_q8_0 * x = (const block_q8_0 *) vx; + + const dst_t d = x[ib].d; + + v.x = x[ib].qs[iqs + 0]; + v.y = x[ib].qs[iqs + 1]; + + const vec2_t d2 = make_vec2_t(d, d); + + v = v * d2; +} + +//================================== k-quants + +static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { + + const int i = blockIdx.x; + const int tid = threadIdx.x; + const int n = tid/32; + const int l = tid - 32*n; + const int is = 8*n + l/16; + + const block_q2_K * x = (const block_q2_K *) vx; + + const uint8_t q = x[i].qs[32*n + l]; + float * y = yy + i*QK_K + 128*n; + + float dall = x[i].d; + float dmin = x[i].dmin; + y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); + y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); + y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); + +} + +static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { + + const block_q2_K * x = (const block_q2_K *) vx; + + // if n is 0, we want to do the lower 128, else the upper 128, + // covering y[l+0], y[l+32], y[l+64], y[l+96] and + // y[l+16], y[l+48], y[l+80], y[l+112] + int n = iqs/128; // 0 or 1 + int r = iqs - 128*n; // 0...120 in steps of 8 + int l = r/8; // 0...15 in steps of 1 + + const float * y = yy + 128*n + l; + const uint8_t * q = x[ib].qs + 32*n + l; + const uint8_t * s = x[ib].scales + 8*n; + + const float dall = x[ib].d; + const float dmin = x[ib].dmin; + + float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4)) + + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4)) + + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4)) + + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4)) + + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4)) + + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4)) + + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4)) + + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4)); + + result = sum; + +} + +static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { + + int r = threadIdx.x/4; + int i = blockIdx.x; + int tid = r/2; + int is0 = r%2; + int l0 = 16*is0 + 4*(threadIdx.x%4); + int n = tid / 4; + int j = tid - 4*n; + + const block_q3_K * x = (const block_q3_K *) vx; + + uint8_t m = 1 << (4*n + j); + int is = 8*n + 2*j + is0; + int shift = 2*j; + + int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); + float d_all = x[i].d; + float dl = d_all * (us - 32); + + float * y = yy + i*QK_K + 128*n + 32*j; + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 
0 : 4)); + +} + +static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { + + const block_q3_K * x = (const block_q3_K *) vx; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + uint32_t aux[3]; + uint32_t utmp[4]; + + // if n is 0, we want to do the lower 128, else the upper 128, + // covering y[l+0], y[l+32], y[l+64], y[l+96] and + // y[l+16], y[l+48], y[l+80], y[l+112] + int n = iqs/128; // 0 or 1 + int r = iqs - 128*n; // 0...120 in steps of 8 + int l = r/8; // 0...15 in steps of 1 + + const float * y = yy + 128*n + l; + const uint8_t * q = x[ib].qs + 32*n + l; + const uint8_t * hm = x[ib].hmask + l; + const int8_t * s = (const int8_t *)utmp + 8*n; + + memcpy(aux, x[ib].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + const float dall = x[ib].d; + + const uint8_t m = 1 << (4*n); + + float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4)) + + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4)) + + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4)) + + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4)) + + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4)) + + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4)) + + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4)) + + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4)); + + result = sum * dall; + +} + +static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { + if (j < 4) { + d = q[j] & 63; m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { + const block_q4_K * x = (const block_q4_K *) vx; + + const int i = blockIdx.x; + + //// assume 64 threads - this is very slightly better than the one below + //const int tid = threadIdx.x; + //const int il = tid/16; + //const int ir = tid%16; + //const int is = 2*il; + //const int n = 2; + + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int is = 2*il; + const int n = 4; + + float * y = yy + i*QK_K + 64*il + n*ir; + + const float dall = x[i].d; + const float dmin = x[i].dmin; + + const uint8_t * q = x[i].qs + 32*il + n*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + for (int l = 0; l < n; ++l) { + y[l + 0] = d1 * (q[l] & 0xF) - m1; + y[l +32] = d2 * (q[l] >> 4) - m2; + } +} + +static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { + + const block_q4_K * x = (const block_q4_K *) vx; + + // iqs is in 0...248 in steps of 8 => + const int j = iqs / 64; // j is in 0...3 + const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4 + const int is = 2*j; // is is in 0...6 in steps of 2 + + const float * y = yy + 64*j + ir; + const uint8_t * q = x[ib].qs + 32*j + ir; + + const float dall = 
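// Illustrative sketch (not part of the patch): the q4_K reconstruction used by
// dequantize_block_q4_K and vec_dot_q4_K above. A 256-weight super-block keeps
// two half-precision factors (d and dmin) plus 6-bit per-sub-block scales and
// mins, unpacked by get_scale_min_k4; every weight is then d*sc*q - dmin*m with
// q a 4-bit quant. decode_q4_K_weight is a hypothetical helper name.
static float decode_q4_K_weight(float d, float dmin, uint8_t sc, uint8_t m, int q4) {
    // sc and m are the unpacked 6-bit values (0..63), q4 is a nibble (0..15)
    return d * sc * q4 - dmin * m;
}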
x[ib].d; + const float dmin = x[ib].dmin; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[ib].scales, sc, m); + const float d1 = dall * sc; + const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[ib].scales, sc, m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + float sum = 0; + for (int k = 0; k < 4; ++k) { + sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1); + sum += y[k + 32] * (d2 * (q[k] >> 4) - m2); + } + result = sum; + +} + +static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { + const block_q5_K * x = (const block_q5_K *) vx; + + const int i = blockIdx.x; + + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int il = tid/16; // il is in 0...3 + const int ir = tid%16; // ir is in 0...15 + const int is = 2*il; // is is in 0...6 + + float * y = yy + i*QK_K + 64*il + 2*ir; + + const float dall = x[i].d; + const float dmin = x[i].dmin; + + const uint8_t * ql = x[i].qs + 32*il + 2*ir; + const uint8_t * qh = x[i].qh + 2*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + + uint8_t hm = 1 << (2*il); + y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; + y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; + hm <<= 1; + y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; + y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +} + +static __device__ void vec_dot_q5_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { + + const block_q5_K * x = (const block_q5_K *) vx; + + // iqs is in 0...248 in steps of 8 => + const int j = iqs / 64; // j is in 0...3 + const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4 + const int is = 2*j; // is is in 0...6 in steps of 2 + + const float * y = yy + 64*j + ir; + const uint8_t * ql = x[ib].qs + 32*j + ir; + const uint8_t * qh = x[ib].qh + ir; + + const float dall = x[ib].d; + const float dmin = x[ib].dmin; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[ib].scales, sc, m); + const float d1 = dall * sc; + const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[ib].scales, sc, m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + uint8_t hm = 1 << is; + float sum = 0; + for (int k = 0; k < 4; ++k) { + sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1); + } + hm <<= 1; + for (int k = 0; k < 4; ++k) { + sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 
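// Illustrative sketch (not part of the patch): how the q5_K kernels above turn a
// stored nibble back into a 5-bit quant. The low four bits come from qs and the
// fifth bit from the matching position in qh; when that bit is set it adds 16,
// and the weight is then d*sc*q5 - dmin*m exactly as in the q4_K case.
// rebuild_q5 is a hypothetical helper name.
static int rebuild_q5(uint8_t ql_nibble, uint8_t qh_byte, uint8_t hm_bit) {
    return (ql_nibble & 0xF) + ((qh_byte & hm_bit) ? 16 : 0);  // 0..31
}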
16 : 0)) - m2); + } + result = sum; + +} + +template +static __global__ void dequantize_block_q6_K(const void * vx, dst_t * yy) { + const block_q6_K * x = (const block_q6_K *) vx; + + const int i = blockIdx.x; + + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int ip = tid/32; // ip is 0 or 1 + const int il = tid - 32*ip; // 0...32 + const int is = 8*ip + il/16; + + // TODO: fp16 compute + dst_t * y = yy + i*QK_K + 128*ip + il; + + const float d = x[i].d; + + const uint8_t * ql = x[i].ql + 64*ip + il; + const uint8_t qh = x[i].qh[32*ip + il]; + const int8_t * sc = x[i].scales + is; + + y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +} + +template +static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const src1_t * yy, dst_t * dst, const int ncols, int nrows) { + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q6_K * x = (const block_q6_K *)vx + ib0; + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + dst_t tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const src1_t * y = yy + i * QK_K + y_offset; + const uint8_t * ql = x[i].ql + ql_offset; + const uint8_t * qh = x[i].qh + qh_offset; + const int8_t * s = x[i].scales + s_offset; + + const dst_t d = x[i].d; + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else + dst_t sum = 0; + for (int l = 0; l < 4; ++l) { + sum += (dst_t)y[l+ 0] * (dst_t)s[0] * d * (dst_t)((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + (dst_t)y[l+32] * (dst_t)s[2] * d * (dst_t)((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + (dst_t)y[l+64] * (dst_t)s[4] * d * (dst_t)((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + (dst_t)y[l+96] * 
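// Illustrative sketch (not part of the patch): the q6_K decode performed by
// dequantize_block_q6_K and dequantize_mul_mat_vec_q6_k above. A 6-bit quant is
// assembled from four low bits in ql and two high bits taken from qh, centred by
// subtracting 32, then scaled by the super-block delta d and the signed 8-bit
// sub-block scale sc. decode_q6_K_weight is a hypothetical helper name; qh2 is
// assumed to already hold the two high bits in its low positions.
static float decode_q6_K_weight(float d, int8_t sc, uint8_t ql4, uint8_t qh2) {
    const int q6 = (int8_t)((ql4 & 0xF) | ((qh2 & 3) << 4)) - 32;  // -32..31
    return d * sc * q6;
}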
(dst_t)s[6] * d * (dst_t)((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp += sum; +#endif + + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +template dequantize_kernel> +static __global__ void dequantize_block(const void * vx, dst_t * y, const int k) { + const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; + + if (i >= k) { + return; + } + + const int ib = i/qk; // block index + const int iqs = (i%qk)/qr; // quant index + const int iybs = i - i%qk; // y block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + vec2_t v; + dequantize_kernel(vx, ib, iqs, v); + + y[iybs + iqs + 0] = v.x; + y[iybs + iqs + y_offset] = v.y; +} + +template +static __device__ __forceinline__ dst_t vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + + int vi; + memcpy(&vi, &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); + const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); + const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]); + + const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d); + + // subtract 8 from each quantized value + const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808); + const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808); + + // SIMD dot product of quantized values + int sumi = __dp4a(vi0, ui0, 0); + sumi = __dp4a(vi1, ui1, sumi); + + return sumi*d; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= 600 +} + +template +static __device__ __forceinline__ dst_t vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + + const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]); + const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); + const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]); + + const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d); + const float m = bq4_1->m; + const float s = bq8_1->s; + + const int vi0 = (vi >> 0) & 0x0F0F0F0F; + const int vi1 = (vi >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + int sumi = __dp4a(vi0, ui0, 0); + sumi = __dp4a(vi1, ui1, sumi); + + return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= 600 +} + +template +static __device__ __forceinline__ dst_t vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; + + int qs; + memcpy(&qs, &bq5_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); + const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2); + const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2); + const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); + const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]); + + const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d); + + int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits + vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5 + 
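// Illustrative sketch (not part of the patch): a scalar equivalent of the
// __dp4a() intrinsic that these vec_dot_*_q8_1 kernels rely on. It treats each
// int as four packed signed 8-bit lanes, multiplies them lane-wise and adds the
// result to an accumulator; the surrounding shifts and masks only exist to lay
// the 4/5-bit quants out in that packed form first. dp4a_ref is a hypothetical
// name.
static int dp4a_ref(int a, int b, int acc) {
    for (int i = 0; i < 4; ++i) {
        const int8_t ai = (int8_t)(a >> 8*i);  // signed lane i of a
        const int8_t bi = (int8_t)(b >> 8*i);  // signed lane i of b
        acc += ai * bi;
    }
    return acc;
}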
vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13 + vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21 + vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29 + vi0 = __vsub4(vi0, 0x10101010); // subtract 16 from quantized values + int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values + + int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits + vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5 + vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13 + vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21 + vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29 + vi1 = __vsub4(vi1, 0x10101010); // subtract 16 from quantized values + sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values + + return sumi*d; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= 600 +} + +template +static __device__ __forceinline__ dst_t vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + + const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]); + const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2); + const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2); + const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); + const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]); + + const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d); + const float m = bq5_1->m; + const float s = bq8_1->s; + + int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits + vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5 + vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13 + vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21 + vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29 + int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values + + int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits + vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5 + vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13 + vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21 + vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29 + sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values + + return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= 600 +} + +template +static __device__ __forceinline__ dst_t vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; + + int vi; + memcpy(&vi, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); + const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); + + const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d); + + // SIMD dot product of quantized values + int sumi = __dp4a(vi, ui, 0); + + return sumi*d; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= 600 +} + +template vec_dot_q_cuda> +static __global__ void mul_mat_vec_q(const void * vx, const void * vy, dst_t * dst, const int ncols, const int nrows) { + const int row = blockIdx.y*blockDim.y + threadIdx.y; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = WARP_SIZE / qi; + +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = 0; i < 
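// Illustrative sketch (not part of the patch): the warp-level reduction that
// closes every mul_mat_vec kernel in this file. __shfl_xor_sync returns the
// partial sum held by the lane whose index differs in the masked bit, so halving
// the mask from 16 down to 1 performs a butterfly reduction and leaves the full
// warp sum in every lane after five steps. warp_reduce_sum is a hypothetical
// helper name.
static __device__ float warp_reduce_sum(float v) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, mask, 32);
    }
    return v;
}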
blocks_per_row; i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index + + const int iby = i + threadIdx.x / qi; // y block index + + const int iqs = threadIdx.x % qi; // x block quant index when casting the quants to int + + tmp += (float)vec_dot_q_cuda(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = (dst_t)tmp; + } +} + +template dequantize_kernel> +static __global__ void dequantize_mul_mat_vec(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows) { + // qk = quantized weights per x block + // qr = number of quantized weights per data value in x block + const int row = blockIdx.y*blockDim.y + threadIdx.y; + + if (row >= nrows) { + return; + } + + const int tid = threadIdx.x; + + const int iter_stride = 2*GGML_CUDA_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter + const int y_offset = qr == 1 ? 1 : qk/2; + + vec2_t tmp2 = make_vec2_t(0, 0); // partial sum for thread in warp + + for (int i = 0; i < ncols; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index + const int iybs = col - col%qk; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + + // dequantize + vec2_t xc; + dequantize_kernel(vx, ib, iqs + j/qr, xc); + + // matrix multiplication + vec2_t yc = make_vec2_t( + y[iybs + iqs + j/qr + 0], + y[iybs + iqs + j/qr + y_offset]); + tmp2 += xc * yc; + } + } + + // sum up partial sums and write back result + // TODO: reducing as half2 may be faster, but requires special handling for float2 + dst_t tmp = tmp2.x + tmp2.y; +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +template dot_kernel> +static __global__ void dequantize_mul_mat_vec_k(const void * vx, const src1_t * y, dst_t * dst, const int ncols) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + const int iter_stride = QK_K; + const int vals_per_iter = iter_stride / n_thread; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + dst_t tmp = 0; // partial sum for thread in warp + + for (int i = 0; i < ncols; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = ib0 + col/QK_K; // x block index + const int iqs = col%QK_K; // x quant index + const int iybs = col - col%QK_K; // y block start index + + dst_t v; + dot_kernel(vx, ib, iqs, y + iybs, v); + tmp += v; + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0646fa7b2..d31823d81 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1,19 +1,64 @@ +static const int GGML_CUDA_MAX_SUBSTREAMS = 1; +static const bool GGML_CUDA_SEQ_COMPUTE = true; + +#define WARP_SIZE 32 +#define CUDA_ADD_BLOCK_SIZE 256 +#define CUDA_MUL_BLOCK_SIZE 
256 +#define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_CPY_BLOCK_SIZE 32 +#define CUDA_SCALE_BLOCK_SIZE 256 +#define CUDA_ROPE_BLOCK_SIZE 256 +#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 +#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 +#define CUDA_GET_ROWS_BLOCK_SIZE 256 +#define CUDA_QUANTIZE_BLOCK_SIZE 256 + +// dmmv = dequantize_mul_mat_vec +#ifndef GGML_CUDA_DMMV_X +#define GGML_CUDA_DMMV_X 32 +#endif +#ifndef GGML_CUDA_DMMV_Y +#define GGML_CUDA_DMMV_Y 1 +#endif +#ifndef GGML_CUDA_MMV_Y +#define GGML_CUDA_MMV_Y 1 +#endif + + +#ifndef K_QUANTS_PER_ITERATION +#define K_QUANTS_PER_ITERATION 2 +#else +static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); +#endif + +#include +#include +#include +#include +#include #include #include #include +#include +#include #include #include -#include -#include +#include +#include +#include +#include +#include +#include #include #include -#include +#include +#include -#include "ggml-cuda.h" #include "ggml.h" - -#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#include "ggml-cuda.h" +#include "ggml-cuda-kern.h" +#include "ggml-cuda-quant.h" #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -25,8 +70,8 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); do { \ cudaError_t err_ = (err); \ if (err_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ - cudaGetErrorString(err_)); \ + fprintf(stderr, "CUDA error %d at %s (%s:%d): %s\n", err_, \ + __func__, __FILE__, __LINE__, cudaGetErrorString(err_)); \ exit(1); \ } \ } while (0) @@ -36,8 +81,8 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); do { \ cublasStatus_t err_ = (err); \ if (err_ != CUBLAS_STATUS_SUCCESS) { \ - fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \ - err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \ + fprintf(stderr, "\ncuBLAS error %d at %s (%s:%d): %s\n", err_, \ + __func__, __FILE__, __LINE__, cublasGetStatusString(err_)); \ exit(1); \ } \ } while (0) @@ -50,2037 +95,115 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); exit(1); \ } \ } while (0) -#endif // CUDART_VERSION >= 11 +#endif // CUDART_VERSION >= 12000 -#ifdef GGML_CUDA_DMMV_F16 -typedef half dfloat; // dequantize float -typedef half2 dfloat2; -#else -typedef float dfloat; // dequantize float -typedef float2 dfloat2; -#endif //GGML_CUDA_DMMV_F16 +#define UNUSED(x) (void)(x) -typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); -typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream); -typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v); -typedef void (*cpy_kernel_t)(const char * cx, char * cdst); -typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); typedef void (*ggml_cuda_op_t)( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i, - float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main); + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + void * src0_d, void * src1_d, void * dst_d, + int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t cudaStream_main); -// QK = 
number of values after dequantization -// QR = QK / number of values before dequantization -// QI = number of 32 bit integers before dequantization - -#define QK4_0 32 -#define QR4_0 2 -#define QI4_0 (QK4_0 / (4 * QR4_0)) -typedef struct { - half d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); - -#define QK4_1 32 -#define QR4_1 2 -#define QI4_1 (QK4_1 / (4 * QR4_1)) -typedef struct { - half d; // delta - half m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; -static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); - -#define QK5_0 32 -#define QR5_0 2 -#define QI5_0 (QK5_0 / (4 * QR5_0)) -typedef struct { - half d; // delta - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; -static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); - -#define QK5_1 32 -#define QR5_1 2 -#define QI5_1 (QK5_1 / (4 * QR5_1)) -typedef struct { - half d; // delta - half m; // min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants -} block_q5_1; -static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); - -#define QK8_0 32 -#define QR8_0 1 -#define QI8_0 (QK8_0 / (4 * QR8_0)) -typedef struct { - half d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); - -#define QK8_1 32 -#define QR8_1 1 -#define QI8_1 (QK8_1 / (4 * QR8_1)) -typedef struct { - half d; // delta - half s; // unquantized sum - int8_t qs[QK8_0]; // quants -} block_q8_1; -static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); - -typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs); - -//================================= k-quants - -#ifdef GGML_QKK_64 -#define QK_K 64 -#define K_SCALE_SIZE 4 -#else -#define QK_K 256 -#define K_SCALE_SIZE 12 -#endif - -#define QR2_K 4 -#define QI2_K (QK_K / (4*QR2_K)) -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins -} block_q2_K; -static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); - -#define QR3_K 4 -#define QI3_K (QK_K / (4*QR3_K)) -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits -#ifdef GGML_QKK_64 - uint8_t scales[2]; // scales, quantized with 8 bits -#else - uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits -#endif - half d; // super-block scale -} block_q3_K; -//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding"); - -#define QR4_K 2 -#define QI4_K (QK_K / (4*QR4_K)) -#ifdef GGML_QKK_64 -typedef struct { - half d[2]; // super-block scales/mins - uint8_t scales[2]; // 4-bit block scales/mins - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); -#else -typedef struct { - half d; // super-block scale for quantized 
scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); -#endif - -#define QR5_K 2 -#define QI5_K (QK_K / (4*QR5_K)) -#ifdef GGML_QKK_64 -typedef struct { - half d; // super-block scale - int8_t scales[QK_K/16]; // block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); -#endif - -#define QR6_K 2 -#define QI6_K (QK_K / (4*QR6_K)) -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales - half d; // delta -} block_q6_K; -static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); - -#define WARP_SIZE 32 -#define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses - -#define CUDA_ADD_BLOCK_SIZE 256 -#define CUDA_MUL_BLOCK_SIZE 256 -#define CUDA_GELU_BLOCK_SIZE 256 -#define CUDA_SILU_BLOCK_SIZE 256 -#define CUDA_CPY_BLOCK_SIZE 32 -#define CUDA_SCALE_BLOCK_SIZE 256 -#define CUDA_ROPE_BLOCK_SIZE 256 -#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 -#define CUDA_QUANTIZE_BLOCK_SIZE 256 -#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 - -// dmmv = dequantize_mul_mat_vec -#ifndef GGML_CUDA_DMMV_X -#define GGML_CUDA_DMMV_X 32 -#endif -#ifndef GGML_CUDA_MMV_Y -#define GGML_CUDA_MMV_Y 1 -#endif - -#ifndef K_QUANTS_PER_ITERATION -#define K_QUANTS_PER_ITERATION 2 -#else -static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); -#endif - -struct ggml_tensor_extra_gpu { - void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors - cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs +struct cuda_pool_buffer { + void * ptr; + size_t size; }; -static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= kx) { - return; - } - dst[i] = x[i] + y[i%ky]; -} - -static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = __hadd(x[i], __float2half(y[i])); -} - -static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= kx) { - return; - } - dst[i] = x[i] * y[i%ky]; -} - -static __global__ void gelu_f32(const float * x, float * dst, const int k) { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - 
} - - float xi = x[i]; - dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); -} - -static __global__ void silu_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = x[i] / (1.0f + expf(-x[i])); -} - -static __global__ void norm_f32(const float * x, float * dst, const int ncols) { - const int row = blockIdx.x*blockDim.y + threadIdx.y; - const int tid = threadIdx.x; - - const float eps = 1e-5f; - - float mean = 0.0f; - float var = 0.0f; - - for (int col = tid; col < ncols; col += WARP_SIZE) { - const float xi = x[row*ncols + col]; - mean += xi; - var += xi * xi; - } - - // sum up partial sums -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - mean += __shfl_xor_sync(0xffffffff, mean, mask, 32); - var += __shfl_xor_sync(0xffffffff, var, mask, 32); - } - - mean /= ncols; - var = var / ncols - mean * mean; - const float inv_var = rsqrtf(var + eps); - - for (int col = tid; col < ncols; col += WARP_SIZE) { - dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var; - } -} - -static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) { - const int row = blockIdx.x*blockDim.y + threadIdx.y; - const int tid = threadIdx.x; - - const float eps = 1e-6f; - - float tmp = 0.0f; // partial sum for thread in warp - - for (int col = tid; col < ncols; col += WARP_SIZE) { - const float xi = x[row*ncols + col]; - tmp += xi * xi; - } - - // sum up partial sums -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - const float mean = tmp / ncols; - const float scale = rsqrtf(mean + eps); - - for (int col = tid; col < ncols; col += WARP_SIZE) { - dst[row*ncols + col] = scale * x[row*ncols + col]; - } -} - -static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const block_q4_0 * x = (const block_q4_0 *) vx; - - const dfloat d = x[ib].d; - - const int vui = x[ib].qs[iqs]; - - v.x = vui & 0xF; - v.y = vui >> 4; - -#ifdef GGML_CUDA_DMMV_F16 - v = __hsub2(v, {8.0f, 8.0f}); - v = __hmul2(v, {d, d}); -#else - v.x = (v.x - 8.0f) * d; - v.y = (v.y - 8.0f) * d; -#endif // GGML_CUDA_DMMV_F16 -} - -static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const block_q4_1 * x = (const block_q4_1 *) vx; - - const dfloat d = x[ib].d; - const dfloat m = x[ib].m; - - const int vui = x[ib].qs[iqs]; - - v.x = vui & 0xF; - v.y = vui >> 4; - -#ifdef GGML_CUDA_DMMV_F16 - v = __hmul2(v, {d, d}); - v = __hadd2(v, {m, m}); -#else - v.x = (v.x * d) + m; - v.y = (v.y * d) + m; -#endif // GGML_CUDA_DMMV_F16 -} - -static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const block_q5_0 * x = (const block_q5_0 *) vx; - - const dfloat d = x[ib].d; - - uint32_t qh; - memcpy(&qh, x[ib].qh, sizeof(qh)); - - const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; - const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; - - v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); - v.y = ((x[ib].qs[iqs] >> 4) | xh_1); - -#ifdef GGML_CUDA_DMMV_F16 - v = __hsub2(v, {16.0f, 16.0f}); - v = __hmul2(v, {d, d}); -#else - v.x = (v.x - 16.0f) * d; - v.y = (v.y - 16.0f) * d; -#endif // GGML_CUDA_DMMV_F16 -} - -static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const block_q5_1 * x = (const block_q5_1 *) vx; - - const dfloat 
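// Illustrative sketch (not part of the patch): a scalar reference for the
// rms_norm_f32 kernel visible in the removed block above. Each row is scaled by
// 1/sqrt(mean(x^2) + eps), with eps = 1e-6f in the kernel; the device version
// simply splits the sum of squares across a warp and reduces it with
// __shfl_xor_sync. rms_norm_row_ref is a hypothetical name; assumes <math.h>.
static void rms_norm_row_ref(const float * x, float * y, int n, float eps) {
    float ss = 0.0f;
    for (int i = 0; i < n; ++i) {
        ss += x[i] * x[i];
    }
    const float scale = 1.0f / sqrtf(ss / n + eps);
    for (int i = 0; i < n; ++i) {
        y[i] = scale * x[i];
    }
}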
d = x[ib].d; - const dfloat m = x[ib].m; - - uint32_t qh; - memcpy(&qh, x[ib].qh, sizeof(qh)); - - const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; - const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; - - v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); - v.y = ((x[ib].qs[iqs] >> 4) | xh_1); - -#ifdef GGML_CUDA_DMMV_F16 - v = __hmul2(v, {d, d}); - v = __hadd2(v, {m, m}); -#else - v.x = (v.x * d) + m; - v.y = (v.y * d) + m; -#endif // GGML_CUDA_DMMV_F16 -} - -static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const block_q8_0 * x = (const block_q8_0 *) vx; - - const dfloat d = x[ib].d; - - v.x = x[ib].qs[iqs + 0]; - v.y = x[ib].qs[iqs + 1]; - -#ifdef GGML_CUDA_DMMV_F16 - v = __hmul2(v, {d, d}); -#else - v.x *= d; - v.y *= d; -#endif // GGML_CUDA_DMMV_F16 -} - -//================================== k-quants - -static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) { - - const int i = blockIdx.x; - const block_q2_K * x = (const block_q2_K *) vx; - - const int tid = threadIdx.x; -#if QK_K == 256 - const int n = tid/32; - const int l = tid - 32*n; - const int is = 8*n + l/16; - - const uint8_t q = x[i].qs[32*n + l]; - float * y = yy + i*QK_K + 128*n; - - float dall = x[i].d; - float dmin = x[i].dmin; - y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); - y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); - y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); - y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); -#else - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 - const uint8_t q = x[i].qs[il] >> (2*is); - float * y = yy + i*QK_K + 16*is + il; - float dall = x[i].d; - float dmin = x[i].dmin; - y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); - y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); -#endif - -} - -static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) { - - const int i = blockIdx.x; - const block_q3_K * x = (const block_q3_K *) vx; - -#if QK_K == 256 - const int r = threadIdx.x/4; - const int tid = r/2; - const int is0 = r%2; - const int l0 = 16*is0 + 4*(threadIdx.x%4); - const int n = tid / 4; - const int j = tid - 4*n; - - uint8_t m = 1 << (4*n + j); - int is = 8*n + 2*j + is0; - int shift = 2*j; - - int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : - is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : - is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : - (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); - float d_all = x[i].d; - float dl = d_all * (us - 32); - - float * y = yy + i*QK_K + 128*n + 32*j; - const uint8_t * q = x[i].qs + 32*n; - const uint8_t * hm = x[i].hmask; - - for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 
0 : 4)); -#else - const int tid = threadIdx.x; - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 - const int im = il/8; // 0...1 - const int in = il%8; // 0...7 - - float * y = yy + i*QK_K + 16*is + il; - - const uint8_t q = x[i].qs[il] >> (2*is); - const uint8_t h = x[i].hmask[in] >> (2*is + im); - const float d = (float)x[i].d; - - if (is == 0) { - y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); - y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); - } else { - y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); - y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); - } -#endif - -} - -#if QK_K == 256 -static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { - if (j < 4) { - d = q[j] & 63; m = q[j + 4] & 63; - } else { - d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); - } -} -#endif - -static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) { - const block_q4_K * x = (const block_q4_K *) vx; - - const int i = blockIdx.x; - -#if QK_K == 256 - // assume 32 threads - const int tid = threadIdx.x; - const int il = tid/8; - const int ir = tid%8; - const int is = 2*il; - const int n = 4; - - float * y = yy + i*QK_K + 64*il + n*ir; - - const float dall = x[i].d; - const float dmin = x[i].dmin; - - const uint8_t * q = x[i].qs + 32*il + n*ir; - - uint8_t sc, m; - get_scale_min_k4(is + 0, x[i].scales, sc, m); - const float d1 = dall * sc; const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[i].scales, sc, m); - const float d2 = dall * sc; const float m2 = dmin * m; - for (int l = 0; l < n; ++l) { - y[l + 0] = d1 * (q[l] & 0xF) - m1; - y[l +32] = d2 * (q[l] >> 4) - m2; - } -#else - const int tid = threadIdx.x; - const uint8_t * q = x[i].qs; - float * y = yy + i*QK_K; - const float d = (float)x[i].d[0]; - const float m = (float)x[i].d[1]; - y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); - y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4); -#endif -} - -static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) { - const block_q5_K * x = (const block_q5_K *) vx; - - const int i = blockIdx.x; - -#if QK_K == 256 - // assume 64 threads - this is very slightly better than the one below - const int tid = threadIdx.x; - const int il = tid/16; // il is in 0...3 - const int ir = tid%16; // ir is in 0...15 - const int is = 2*il; // is is in 0...6 - - float * y = yy + i*QK_K + 64*il + 2*ir; - - const float dall = x[i].d; - const float dmin = x[i].dmin; - - const uint8_t * ql = x[i].qs + 32*il + 2*ir; - const uint8_t * qh = x[i].qh + 2*ir; - - uint8_t sc, m; - get_scale_min_k4(is + 0, x[i].scales, sc, m); - const float d1 = dall * sc; const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[i].scales, sc, m); - const float d2 = dall * sc; const float m2 = dmin * m; - - uint8_t hm = 1 << (2*il); - y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; - y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; - hm <<= 1; - y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; - y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 
16 : 0)) - m2; -#else - const int tid = threadIdx.x; - const uint8_t q = x[i].qs[tid]; - const int im = tid/8; // 0...3 - const int in = tid%8; // 0...7 - const int is = tid/16; // 0 or 1 - const uint8_t h = x[i].qh[in] >> im; - const float d = x[i].d; - float * y = yy + i*QK_K + tid; - y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); - y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16)); -#endif -} - -static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) { - const block_q6_K * x = (const block_q6_K *) vx; - - const int i = blockIdx.x; -#if QK_K == 256 - - // assume 64 threads - this is very slightly better than the one below - const int tid = threadIdx.x; - const int ip = tid/32; // ip is 0 or 1 - const int il = tid - 32*ip; // 0...32 - const int is = 8*ip + il/16; - - float * y = yy + i*QK_K + 128*ip + il; - - const float d = x[i].d; - - const uint8_t * ql = x[i].ql + 64*ip + il; - const uint8_t qh = x[i].qh[32*ip + il]; - const int8_t * sc = x[i].scales + is; - - y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); - y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); - y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); - y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); -#else - - // assume 32 threads - const int tid = threadIdx.x; - const int ip = tid/16; // 0 or 1 - const int il = tid - 16*ip; // 0...15 - - float * y = yy + i*QK_K + 16*ip + il; - - const float d = x[i].d; - - const uint8_t ql = x[i].ql[16*ip + il]; - const uint8_t qh = x[i].qh[il] >> (2*ip); - const int8_t * sc = x[i].scales; - - y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32); - y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32); -#endif -} - -static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - - const int row = blockIdx.y*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q2_K * x = (const block_q2_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - -#if QK_K == 256 - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int step = 16/K_QUANTS_PER_ITERATION; - - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
- const int in = tid - step*im; // 0...15 or 0...7 - - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 - const int q_offset = 32*im + l0; - const int s_offset = 8*im; - const int y_offset = 128*im + l0; - - uint32_t aux[4]; - const uint8_t * d = (const uint8_t *)aux; - const uint8_t * m = (const uint8_t *)(aux + 2); - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * q = x[i].qs + q_offset; - - const float dall = x[i].d; - const float dmin = x[i].dmin; - - const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); - aux[0] = a[0] & 0x0f0f0f0f; - aux[1] = a[1] & 0x0f0f0f0f; - aux[2] = (a[0] >> 4) & 0x0f0f0f0f; - aux[3] = (a[1] >> 4) & 0x0f0f0f0f; - - float sum1 = 0, sum2 = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) - + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) - + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) - + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) - + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) - + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) - + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) - +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); - sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] - + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; - - } - tmp += dall * sum1 - dmin * sum2; - - } -#else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 - const int offset = tid * K_QUANTS_PER_ITERATION; - - uint32_t uaux[2]; - const uint8_t * d = (const uint8_t *)uaux; - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + offset; - const uint8_t * q = x[i].qs + offset; - const uint32_t * s = (const uint32_t *)x[i].scales; - - uaux[0] = s[0] & 0x0f0f0f0f; - uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; - - const half2 * dh = (const half2 *)&x[i].d; - - const float2 dall = __half22float2(dh[0]); - - float sum1 = 0, sum2 = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - const uint8_t ql = q[l]; - sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) - + y[l+16] * d[1] * ((ql >> 2) & 3) - + y[l+32] * d[2] * ((ql >> 4) & 3) - + y[l+48] * d[3] * ((ql >> 6) & 3); - sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; - } - tmp += dall.x * sum1 - dall.y * sum2; - } -#endif - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - const int row = blockIdx.y*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q3_K * x = (const block_q3_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - -#if QK_K == 256 - - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop - const int step = 16/K_QUANTS_PER_ITERATION; - const int im = tid/step; // 0 or 1. 
0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0....15 or 0...7 - - const uint8_t m = 1 << (4*im); - - const int l0 = n*in; // 0...15 or 0...14 in steps of 2 - const int q_offset = 32*im + l0; - const int y_offset = 128*im + l0; - - uint16_t utmp[4]; - const int8_t * s = (const int8_t *)utmp; - - const uint16_t s_shift = 4*im; - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * q = x[i].qs + q_offset; - const uint8_t * h = x[i].hmask + l0; - - const uint16_t * a = (const uint16_t *)x[i].scales; - utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); - utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); - utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); - utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); - - const float d = x[i].d; - - float sum = 0; - for (int l = 0; l < n; ++l) { - sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) - + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) - + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) - + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); - sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) - + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) - + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) - + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); - } - tmp += d * sum; - - } -#else - - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 - const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 - const int in = offset/8; // 0 or 1 - const int im = offset%8; // 0...7 - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + offset; - const uint8_t * q = x[i].qs + offset; - const uint8_t * s = x[i].scales; - - const float dall = (float)x[i].d; - - float sum = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - const uint8_t hl = x[i].hmask[im+l] >> in; - const uint8_t ql = q[l]; - sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) - + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4)) - + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4)) - + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 
0 : 4)); - } - tmp += sum; - } -#endif - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - const int row = blockIdx.y*blockDim.y + threadIdx.y; - if (row > nrows) return; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q4_K * x = (const block_q4_K *)vx + ib0; - -#if QK_K == 256 - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 - - const int il = tid/step; // 0...3 - const int ir = tid - step*il; // 0...7 or 0...3 - const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; - - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - uint16_t aux[4]; - const uint8_t * sc = (const uint8_t *)aux; - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const uint8_t * q1 = x[i].qs + q_offset; - const uint8_t * q2 = q1 + 64; - const float * y1 = yy + i*QK_K + y_offset; - const float * y2 = y1 + 128; - - const float dall = x[i].d; - const float dmin = x[i].dmin; - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux[0] = a[im+0] & kmask1; - aux[1] = a[im+2] & kmask1; - aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); - aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { - s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4); - s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4); - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; - } - tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; - - } -#else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); - - const int step = tid * K_QUANTS_PER_ITERATION; - - uint16_t aux16[2]; - const uint8_t * s = (const uint8_t *)aux16; - - float tmp = 0; - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - const uint8_t * q = x[i].qs + step; - const float * y = yy + i*QK_K + step; - const uint16_t * a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - const float d = (float)x[i].d[0]; - const float m = (float)x[i].d[1]; - float sum = 0.f; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { - sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) - + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) - + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) - + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); - } - tmp += sum; - } - -#endif - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (tid == 0) { - dst[row] = tmp; - } -} - -static 
__global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) { - - const int row = blockIdx.x; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q5_K * x = (const block_q5_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - -#if QK_K == 256 - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = threadIdx.x/2; // 0...15 - const int ix = threadIdx.x%2; - - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 2; - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; - - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - const uint8_t hm1 = 1 << (2*im); - const uint8_t hm2 = hm1 << 4; - - uint16_t aux[4]; - const uint8_t * sc = (const uint8_t *)aux; - - for (int i = ix; i < num_blocks_per_row; i += 2) { - - const uint8_t * ql1 = x[i].qs + q_offset; - const uint8_t * ql2 = ql1 + 64; - const uint8_t * qh = x[i].qh + l0; - const float * y1 = yy + i*QK_K + y_offset; - const float * y2 = y1 + 128; - - const float dall = x[i].d; - const float dmin = x[i].dmin; - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux[0] = a[im+0] & kmask1; - aux[1] = a[im+2] & kmask1; - aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); - aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); - - float4 sum = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { - sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) - + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0)); - sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) - + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0)); - sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) - + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0)); - sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) - + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0)); - smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] - + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; - } - tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; - } - -#else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); - const int step = tid * K_QUANTS_PER_ITERATION; - const int im = step/8; - const int in = step%8; - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - const uint8_t * q = x[i].qs + step; - const int8_t * s = x[i].scales; - const float * y = yy + i*QK_K + step; - const float d = x[i].d; - float sum = 0.f; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { - const uint8_t h = x[i].qh[in+j] >> im; - sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) - + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) - + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16)) - + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 
0 : 16)); - } - tmp += sum; - } -#endif - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - - const int row = blockIdx.y*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q6_K * x = (const block_q6_K *)vx + ib0; - -#if QK_K == 256 - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...15 or 0...7 - -#if K_QUANTS_PER_ITERATION == 1 - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 - const int is = 0; -#else - const int l0 = 4 * in; // 0, 4, 8, ..., 28 - const int is = in / 4; -#endif - const int ql_offset = 64*im + l0; - const int qh_offset = 32*im + l0; - const int s_offset = 8*im + is; - const int y_offset = 128*im + l0; - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * ql = x[i].ql + ql_offset; - const uint8_t * qh = x[i].qh + qh_offset; - const int8_t * s = x[i].scales + s_offset; - - const float d = x[i].d; - -#if K_QUANTS_PER_ITERATION == 1 - float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) - + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) - + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) - + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) - + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) - + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) - + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) - +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); - tmp += sum; -#else - float sum = 0; - for (int l = 0; l < 4; ++l) { - sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) - + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) - + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) - + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); - } - tmp += sum; -#endif - - } - -#else - - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3 - - const int step = tid * K_QUANTS_PER_ITERATION; - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + step; - const uint8_t * ql = x[i].ql + step; - const uint8_t * qh = x[i].qh + step; - const int8_t * s = x[i].scales; - - const float d = x[i+0].d; - - float sum = 0; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { - sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] 
& 0xF) | ((qh[j] & 0x03) << 4)) - 32) - + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) - + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) - + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); - } - tmp += sum; - - } - -#endif - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (tid == 0) { - dst[row] = tmp; - } -} - -static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const half * x = (const half *) vx; - - // automatic half -> float type cast if dfloat == float - v.x = x[ib + iqs + 0]; - v.y = x[ib + iqs + 1]; -} - -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int ndata, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - block_q8_1 * y = (block_q8_1 *) vy; - - const int ib = i / QK8_1; // block index - const int iqs = i % QK8_1; // quant index - - const float xi = i < ndata ? x[i] : 0.0f; - float amax = fabsf(xi); - float sum = xi; - -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); - sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); - } - - const float d = amax / 127; - const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); - - y[ib].qs[iqs] = q; - - if (iqs > 0) { - return; - } - - y[ib].d = d; - y[ib].s = sum; -} - -template -static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) { - const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; - - if (i >= k) { - return; - } - - const int ib = i/qk; // block index - const int iqs = (i%qk)/qr; // quant index - const int iybs = i - i%qk; // y block start index - const int y_offset = qr == 1 ? 
1 : qk/2; - - // dequantize - dfloat2 v; - dequantize_kernel(vx, ib, iqs, v); - - y[iybs + iqs + 0] = v.x; - y[iybs + iqs + y_offset] = v.y; -} - -static __device__ __forceinline__ float vec_dot_q4_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; - - int vi; - memcpy(&vi, &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); - const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); - const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]); - - const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d); - - // subtract 8 from each quantized value - const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808); - const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808); - - // SIMD dot product of quantized values - int sumi = __dp4a(vi0, ui0, 0); - sumi = __dp4a(vi1, ui1, sumi); - - return sumi*d; -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -static __device__ __forceinline__ float vec_dot_q4_1_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; - - const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]); - const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); - const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]); - - const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d); - const float m = bq4_1->m; - const float s = bq8_1->s; - - const int vi0 = (vi >> 0) & 0x0F0F0F0F; - const int vi1 = (vi >> 4) & 0x0F0F0F0F; - - // SIMD dot product of quantized values - int sumi = __dp4a(vi0, ui0, 0); - sumi = __dp4a(vi1, ui1, sumi); - - return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -static __device__ __forceinline__ float vec_dot_q5_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; - - int qs; - memcpy(&qs, &bq5_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); - const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2); - const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2); - const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); - const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]); - - const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d); - - int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits - vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5 - vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13 - vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21 - vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29 - vi0 = __vsub4(vi0, 0x10101010); // subtract 16 from quantized values - int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values - - int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits - vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5 - vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13 - vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21 - vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29 - vi1 = __vsub4(vi1, 0x10101010); // subtract 16 from 
quantized values - sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values - - return sumi*d; -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -static __device__ __forceinline__ float vec_dot_q5_1_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; - - const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]); - const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2); - const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2); - const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); - const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]); - - const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d); - const float m = bq5_1->m; - const float s = bq8_1->s; - - int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits - vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5 - vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13 - vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21 - vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29 - int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values - - int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits - vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5 - vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13 - vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21 - vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29 - sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values - - return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -static __device__ __forceinline__ float vec_dot_q8_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; - - int vi; - memcpy(&vi, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); - const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); - - const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d); - - // SIMD dot product of quantized values - int sumi = __dp4a(vi, ui, 0); - - return sumi*d; -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -static __device__ __forceinline__ float vec_dot_q2_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q2_K * bq2_K = (const block_q2_K *) vbq; - - const int bq8_offset = QR2_K * (iqs / QI8_1); - const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); - - float sumf_d = 0.0f; - float sumf_m = 0.0f; - - const float d = bq2_K->d; - const float dmin = bq2_K->dmin; - - const int v = *((int *) &bq2_K->qs[sizeof(int) * iqs]); - - for (int i = 0; i < QR2_K; ++i) { - const int sc = bq2_K->scales[scale_offset + 2*i]; - - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const float d8i = bq8i->d; - - const int vi = (v >> (2*i)) & 0x03030303; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); - - sumf_d += d8i * (__dp4a(vi, ui, 0) * (sc & 0xF)); // SIMD dot product - sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * (sc >> 4)); // multiply constant q2_K part 
with sum of q8_1 values - } - - return d*sumf_d - dmin*sumf_m; -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -static __device__ __forceinline__ float vec_dot_q3_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q3_K * bq3_K = (const block_q3_K *) vbq; - - const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); - const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); - - float sumf = 0.0f; - - const float d = bq3_K->d; - - int vl; - memcpy(&vl, &bq3_K->qs[sizeof(int) * iqs], sizeof(int)); - - int vh; - memcpy(&vh, &bq3_K->hmask[sizeof(int) * (iqs % (QI3_K/2))], sizeof(int)); - vh = ~vh; // invert the mask so that a 0/1 results in 4/0 being subtracted - vh >>= bq8_offset; - - for (int i = 0; i < QR3_K; ++i) { - const int isc = scale_offset + 2*i; - - const int isc_low = isc % (QK_K/32); - const int sc_shift_low = 4 * (isc / (QK_K/32)); - const int sc_low = (bq3_K->scales[isc_low] >> sc_shift_low) & 0xF; - - const int isc_high = isc % (QK_K/64); - const int sc_shift_high = 2 * (isc / (QK_K/64)); - const int sc_high = ((bq3_K->scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; - - const int sc = (sc_low | sc_high) - 32; - - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); - const float d8i = bq8i->d; - - const int vil = (vl >> (2*i)) & 0x03030303; - - const int vih = ((vh >> i) << 2) & 0x04040404; - - const int vi = __vsubss4(vil, vih); - - sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product - } - - return d*sumf; -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -static __device__ __forceinline__ float vec_dot_q4_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q4_K * bq4_K = (const block_q4_K *) vbq; - - const int bq8_offset = QR4_K * (iqs / QI8_1); - - float sumf_d = 0.0f; - float sumf_m = 0.0f; - - const float d = bq4_K->d; - const float dmin = bq4_K->dmin; - - const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]); - - for (int i = 0; i < QR4_K; ++i) { - const int isc = bq8_offset + i; - - uint8_t sc, m; - get_scale_min_k4(isc, bq4_K->scales, sc, m); - - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); - const float d8i = bq8i->d; - - const int vi = (v >> (4*i)) & 0x0F0F0F0F; - - sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product - sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q4_K with sum of q8_1 values - } - - return d*sumf_d - dmin*sumf_m; -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -static __device__ __forceinline__ float vec_dot_q5_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q5_K * bq5_K = (const block_q5_K *) vbq; - - const int bq8_offset = QR5_K * (iqs / QI8_1); - - float sumf_d = 0.0f; - float sumf_m = 0.0f; - - const float d = bq5_K->d; - const float dmin = bq5_K->dmin; - - const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]); - - const int vh = (*((int *) 
&bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset; - - for (int i = 0; i < QR5_K; ++i) { - const int isc = bq8_offset + i; - - uint8_t sc, m; - get_scale_min_k4(isc, bq5_K->scales, sc, m); - - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); - const float d8i = bq8i->d; - - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; - - const int vih = ((vh >> i) << 4) & 0x10101010; - - const int vi = vil | vih; - - sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product - sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q5_K with sum of q8_1 values - } - - return d*sumf_d - dmin*sumf_m; -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -static __device__ __forceinline__ float vec_dot_q6_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q6_K * bq6_K = (const block_q6_K *) vbq; - - const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); - const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); - const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); - - float sumf = 0.0f; - - const float d = bq6_K->d; - - int vl; - memcpy(&vl, &bq6_K->ql[sizeof(int) * iqs], sizeof(int)); - - int vh; - memcpy(&vh, &bq6_K->qh[sizeof(int) * ((QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4))], sizeof(int)); - - for (int i = 0; i < QR6_K; ++i) { - const int sc = bq6_K->scales[scale_offset + 4*i]; - - const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2*i; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % (QI8_1))]); - const float d8i = bq8i->d; - - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; - - const int vih = ((vh >> (vh_shift + 4*i)) << 4) & 0x30303030; - - const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 - - sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product - } - - return d*sumf; -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -template -static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) { - const int row = blockIdx.y*blockDim.y + threadIdx.y; - - if (row >= nrows) { - return; - } - - const int blocks_per_row = ncols / qk; - const int blocks_per_warp = WARP_SIZE / qi; - -// partial sum for each thread - float tmp = 0.0f; - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - for (int i = 0; i < blocks_per_row; i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index - - const int iby = (i + threadIdx.x / qi) * qk/QK8_1; // y block index that aligns with ibx - - const int iqs = threadIdx.x % qi; // x block quant index when casting the quants to int - - tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs); - } - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -template -static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { - // qk = quantized weights per x block - // qr = number of quantized 
weights per data value in x block - const int row = blockIdx.y*blockDim.y + threadIdx.y; - - if (row >= nrows) { - return; - } - - const int tid = threadIdx.x; - - const int iter_stride = 2*GGML_CUDA_DMMV_X; - const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter - const int y_offset = qr == 1 ? 1 : qk/2; - -// partial sum for each thread -#ifdef GGML_CUDA_DMMV_F16 - half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics -#else - float tmp = 0.0f; -#endif // GGML_CUDA_DMMV_F16 - - for (int i = 0; i < ncols; i += iter_stride) { - const int col = i + vals_per_iter*tid; - const int ib = (row*ncols + col)/qk; // x block index - const int iqs = (col%qk)/qr; // x quant index - const int iybs = col - col%qk; // y block start index - -// processing >2 values per i iter is faster for fast GPUs -#pragma unroll - for (int j = 0; j < vals_per_iter; j += 2) { - // process 2 vals per j iter - - // dequantize - // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val - dfloat2 v; - dequantize_kernel(vx, ib, iqs + j/qr, v); - - // matrix multiplication - // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 -#ifdef GGML_CUDA_DMMV_F16 - tmp += __hmul2(v, { - y[iybs + iqs + j/qr + 0], - y[iybs + iqs + j/qr + y_offset] - }); -#else - tmp += v.x * y[iybs + iqs + j/qr + 0]; - tmp += v.y * y[iybs + iqs + j/qr + y_offset]; -#endif // GGML_CUDA_DMMV_F16 +static std::unordered_map> g_cuda_stream_pools; +static size_t g_cuda_pool_size = 0; + +static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size, cudaStream_t stream) { + std::vector& pool = g_cuda_stream_pools[stream]; + + // find existing + for (size_t i = 0; i < pool.size(); ++i) { + cuda_pool_buffer& b = pool[i]; + if (b.size >= size && b.ptr != nullptr) { + void * ptr = b.ptr; + *actual_size = b.size; + pool.erase(pool.begin() + i); + return ptr; } } - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + // allocate new + void * ptr; + CUDA_CHECK(cudaMalloc(&ptr, size)); + *actual_size = size; - if (tid == 0) { -#ifdef GGML_CUDA_DMMV_F16 - dst[row] = tmp.x + tmp.y; -#else - dst[row] = tmp; -#endif // GGML_CUDA_DMMV_F16 - } + g_cuda_pool_size += size; + + //fprintf(stderr, "cuda pool size: %.2f MB (allocating now: %.2f MB)\n", g_cuda_pool_size / 1024.0 / 1024.0, size / 1024.0 / 1024.0); + + return ptr; } -static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) { - const half * x = (const half *) vx; +static void ggml_cuda_pool_free(void * ptr, size_t size, cudaStream_t stream) { + std::vector& pool = g_cuda_stream_pools[stream]; - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; + pool.push_back({ ptr, size }); +} - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; +static void ggml_cuda_pool_free_all() { + for (auto& p : g_cuda_stream_pools) { + for (auto& b : p.second) { + if (b.ptr != nullptr) { + CUDA_CHECK(cudaFree(b.ptr)); + } } - - // x is transposed and permuted - const int ix = 
row_x*nchannels_x*ncols_x + channel*ncols_x + col_x; - const float xi = __half2float(x[ix]); - - const int row_y = col_x; - - - // y is not transposed but permuted - const int iy = channel*nrows_y + row_y; - - tmp += xi * y[iy]; - } - - // dst is not transposed and not permuted - const int idst = channel*nrows_dst + row_dst; - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[idst] = tmp; } + g_cuda_stream_pools.clear(); } -static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous - const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x) { - - const half * x = (const half *) vx; - - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; - - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - const int idst = channel*nrows_dst + row_dst; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; - } - - const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x; - const float xi = __half2float(x[ix]); - - const int row_y = col_x; - - const int iy = channel*nrows_y + row_y; - - tmp += xi * y[iy]; - } - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[idst] = tmp; - } -} - -static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - float * dsti = (float *) cdsti; - - *dsti = *xi; -} - -static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - half * dsti = (half *) cdsti; - - *dsti = __float2half(*xi); -} - -template -static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, - const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= ne) { - return; - } - - // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor - // then combine those indices with the corresponding byte offsets to get the total offsets - const int i02 = i / (ne00*ne01); - const int i01 = (i - i02*ne01*ne00) / ne00; - const int i00 = i - i02*ne01*ne00 - i01*ne00; - const int x_offset = i00*nb00 + i01*nb01 + i02*nb02; - - const int i12 = i / (ne10*ne11); - const int i11 = (i - i12*ne10*ne11) / ne10; - const int i10 = i - i12*ne10*ne11 - i11*ne10; - const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; - - cpy_1(cx + x_offset, cdst + dst_offset); -} - -// rope == RoPE == rotary positional embedding -static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) { - const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x); - - if (col >= ncols) { - return; - } - - const int row = blockDim.y*blockIdx.y + threadIdx.y; - const int i = row*ncols + col; - - const float theta = p*powf(theta_scale, col/2); - const float sin_theta = sinf(theta); - const float cos_theta = cosf(theta); - - const float x0 
= x[i + 0]; - const float x1 = x[i + 1]; - - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + 1] = x0*sin_theta + x1*cos_theta; -} - -static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) { - const int col = blockDim.x*blockIdx.x + threadIdx.x; - const int half_n_dims = ncols/4; - - if (col >= half_n_dims) { - return; - } - - const int row = blockDim.y*blockIdx.y + threadIdx.y; - const int i = row*ncols + col; - - const float col_theta_scale = powf(theta_scale, col); - - const float theta = p*col_theta_scale; - const float sin_theta = sinf(theta); - const float cos_theta = cosf(theta); - - const float x0 = x[i + 0]; - const float x1 = x[i + half_n_dims]; - - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; - - const float block_theta = block_p*col_theta_scale; - const float sin_block_theta = sinf(block_theta); - const float cos_block_theta = cosf(block_theta); - - const float x2 = x[i + half_n_dims * 2]; - const float x3 = x[i + half_n_dims * 3]; - - dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta; - dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; -} - -static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { - const int col = blockDim.x*blockIdx.x + threadIdx.x; - const int row = blockDim.y*blockIdx.y + threadIdx.y; - - if (col >= ncols) { - return; - } - - const int i = row*ncols + col; - // dst[i] = col > n_past + row ? -INFINITY : x[i]; - dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU -} - -// the CUDA soft max implementation differs from the CPU implementation -// instead of doubles floats are used -// values are also not normalized to the maximum value by subtracting it in the exponential function -// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine -static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) { - const int row = blockDim.y*blockIdx.y + threadIdx.y; - const int block_size = blockDim.x; - const int tid = threadIdx.x; - - float tmp = 0.0; - - for (int block_start = 0; block_start < ncols; block_start += block_size) { - const int col = block_start + tid; - - if (col >= ncols) { - break; - } - - const int i = row*ncols + col; - const float val = expf(x[i]); - tmp += val; - dst[i] = val; - } - - // sum up partial sums -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - for (int block_start = 0; block_start < ncols; block_start += block_size) { - const int col = block_start + tid; - - if (col >= ncols) { - break; - } - - const int i = row*ncols + col; - dst[i] /= tmp; - } -} - -static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - dst[i] = scale * x[i]; -} - -static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { - const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; - add_f32<<>>(x, y, dst, kx, ky); -} - -static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + 
CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; - add_f16_f32_f16<<>>(x, y, dst, k); -} - -static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { - const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; - mul_f32<<>>(x, y, dst, kx, ky); -} - -static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; - gelu_f32<<>>(x, dst, k); -} - -static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; - silu_f32<<>>(x, dst, k); -} - -static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); - const dim3 block_dims(WARP_SIZE, 1, 1); - norm_f32<<>>(x, dst, ncols); -} - -static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols); -} - -static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) { +template +static void quantize_row_q8_1_cuda(const src_t * x, void * vy, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - quantize_q8_1<<>>(x, vy, ndata, k); + quantize_q8_1<<>>(x, vy, k); } -static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); + dequantize_block><<>>(vx, y, k); } -static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); + dequantize_block><<>>(vx, y, k); } -static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); + dequantize_block><<>>(vx, y, k); } -static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); + dequantize_block><<>>(vx, y, k); } -static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); + dequantize_block><<>>(vx, y, k); } +/* static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, 
cudaStream_t stream) { const int nb = k / QK_K; -#if QK_K == 256 dequantize_block_q2_K<<>>(vx, y); -#else - dequantize_block_q2_K<<>>(vx, y); -#endif } static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; -#if QK_K == 256 dequantize_block_q3_K<<>>(vx, y); -#else - dequantize_block_q3_K<<>>(vx, y); -#endif } +template static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q4_K<<>>(vx, y); @@ -2088,101 +211,100 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; -#if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); -#else - dequantize_block_q5_K<<>>(vx, y); -#endif } -static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +*/ +template +static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; -#if QK_K == 256 dequantize_block_q6_K<<>>(vx, y); -#else - dequantize_block_q6_K<<>>(vx, y); -#endif } -static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +template +static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec> <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +template +static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec> <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +template +static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec> <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat 
* y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +template +static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec> <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +template +static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec> <<>>(vx, y, dst, ncols, nrows); } - +/* +template static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 + const int ny = 2; const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(32, ny, 1); dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); } +template static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols, nrows); + const dim3 block_dims(32, 1, 1); + dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols); } +template static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols, nrows); + const dim3 block_dims(32, 1, 1); + dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols); } +template static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const dim3 block_dims(32, 1, 1); dequantize_mul_mat_vec_q5_k<<>>(vx, y, dst, ncols); } +*/ -static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +template +static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols 
% QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; @@ -2191,111 +313,73 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); } -static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK4_0 == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; +template +static void convert_mul_mat_vec_f16_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec><<>>(vx, y, dst, ncols, nrows); } -static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK4_1 == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK5_0 == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK5_1 == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK8_0 == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void 
mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<1, 1, convert_f16><<>>(vx, y, k); -} - -static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +template +static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec<1, 1, convert_f16> - <<>>(vx, y, dst, ncols, nrows); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); } -static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { +template +static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +template +static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +template +static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +template +static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, 
dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +template +static void convert_fp16_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block><<>>(vx, y, k); +} + +template +static to_t_cuda_t ggml_get_to_t_cuda(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; @@ -2307,6 +391,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_q5_1_cuda; case GGML_TYPE_Q8_0: return dequantize_row_q8_0_cuda; + /* case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; case GGML_TYPE_Q3_K: @@ -2315,223 +400,168 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_q4_K_cuda; case GGML_TYPE_Q5_K: return dequantize_row_q5_K_cuda; + */ case GGML_TYPE_Q6_K: return dequantize_row_q6_K_cuda; case GGML_TYPE_F16: - return convert_fp16_to_fp32_cuda; + return convert_fp16_cuda; default: return nullptr; } } -static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) { +template +static void ggml_mul_mat_p021_cuda(const src0_t * vx, const src1_t * y, dst_t * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) { const dim3 block_nums(1, nrows_x, nchannels_x); const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x); + k_mul_mat_p021<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x); } -static void ggml_mul_mat_vec_nc_f16_f32_cuda( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, +template +static void ggml_mul_mat_vec_nc_cuda( + const src0_t * vx, const src1_t * y, dst_t * dst, const int ncols_x, const int nrows_x, const int row_stride_x, const int nchannels_x, const int channel_stride_x, cudaStream_t stream) { const dim3 block_nums(1, nrows_x, nchannels_x); const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x); + k_mul_mat_vec_nc<<>> + (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x); } -static void ggml_cpy_f32_f32_cuda( +template +static void ggml_cpy_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> + k_cpy<<>> (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); } -static void ggml_cpy_f32_f16_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, - const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +template +static void add_cuda(const 
src0_t * x, const src1_t * y, dst_t * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; + k_add<<>>(x, y, dst, k); } -static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { +template +static void mul_cuda(const src0_t * x, const src1_t * y, dst_t * dst, const int kx, const int ky, cudaStream_t stream) { + const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; + k_mul<<>>(x, y, dst, kx, ky); +} + +template +static void silu_cuda(const src0_t * x, dst_t * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + k_silu<<>>(x, dst, k); +} + +template +static void rms_norm_cuda(const src0_t * x, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + const dim3 block_dims(WARP_SIZE, 1, 1); + k_rms_norm<<>>(x, dst, ncols); +} + +template +static void scale_cuda(const src0_t * x, dst_t * dst, const src1_t * scale, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, k); + k_scale<<>>(x, dst, scale, k); } -static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) { +template +static void rope_cuda(const src0_t * x, dst_t * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(nrows % 2 == 0); const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(num_blocks_x, nrows, 1); - rope_f32<<>>(x, dst, ncols, p, theta_scale); + k_rope<<>>(x, dst, ncols, p, theta_scale); } -static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) { - GGML_ASSERT(nrows % 4 == 0); - const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1); - const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE); - const dim3 block_nums(num_blocks_x, nrows, 1); - rope_glm_f32<<>>(x, dst, ncols, p, block_p, theta_scale); -} - -static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { +template +static void diag_mask_inf_cuda(const src0_t * x, dst_t * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1); const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE; const dim3 block_nums(block_num_x, nrows_x, 1); - diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); + k_diag_mask_inf<<>>(x, dst, ncols_x, rows_per_channel, n_past); } -static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) { - const dim3 block_dims(WARP_SIZE, 1, 1); - const dim3 block_nums(1, nrows_x, 1); - soft_max_f32<<>>(x, dst, ncols_x); -} - -// buffer pool for cuda -#define MAX_CUDA_BUFFERS 256 - -struct scoped_spin_lock { - std::atomic_flag& lock; - scoped_spin_lock(std::atomic_flag& lock) : lock(lock) { - while (lock.test_and_set(std::memory_order_acquire)) { - ; // 
spin +template +static void soft_max_cuda(const src0_t * x, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { + // TODO: implement fast numerically stable version for small ncols + //if (ncols >= 1024) { + int num_blocks = nrows; + if (ncols % 2 == 0) { + k_soft_max + <<>>(x, dst, nrows, ncols); } - } - ~scoped_spin_lock() { - lock.clear(std::memory_order_release); - } - scoped_spin_lock(const scoped_spin_lock&) = delete; - scoped_spin_lock& operator=(const scoped_spin_lock&) = delete; -}; - -struct cuda_buffer { - void * ptr = nullptr; - size_t size = 0; -}; - -static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS]; -static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; - -static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { - scoped_spin_lock lock(g_cuda_pool_lock); - int id; - CUDA_CHECK(cudaGetDevice(&id)); - - for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { - cuda_buffer& b = g_cuda_buffer_pool[id][i]; - if (b.size >= size && b.ptr != nullptr) { - void * ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; + else { + k_soft_max + <<>>(x, dst, nrows, ncols); } - } - void * ptr; - CUDA_CHECK(cudaMalloc((void **) &ptr, size)); - *actual_size = size; - return ptr; + //} + //else { + // const dim3 block_dims(WARP_SIZE, 1, 1); + // const dim3 block_nums(1, nrows, 1); + // k_soft_max_orig<<>>(x, dst, ncols); + //} } -static void ggml_cuda_pool_free(void * ptr, size_t size) { - scoped_spin_lock lock(g_cuda_pool_lock); - int id; - CUDA_CHECK(cudaGetDevice(&id)); - - for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { - cuda_buffer& b = g_cuda_buffer_pool[id][i]; - if (b.ptr == nullptr) { - b.ptr = ptr; - b.size = size; - return; - } - } - fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); - CUDA_CHECK(cudaFree(ptr)); +template dq> +static void get_rows_cuda(const void * x, const int * y, dst_t * dst, const int nrows, const int ncols, cudaStream_t stream) { + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); + const int block_num = (ncols/2 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; + const dim3 block_nums(block_num, nrows, 1); + k_get_rows<<>>(x, y, dst, ncols); } - -static void * g_scratch_buffer = nullptr; -static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default -static size_t g_scratch_offset = 0; - -static int g_device_count = -1; -static int g_main_device = 0; +// TODO: move to context +static cublasHandle_t g_cublas_handle = nullptr; +static cudaStream_t g_cudaStream_main = nullptr; +static cudaEvent_t g_cudaEvent_main = nullptr; +static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_SUBSTREAMS] = { }; +static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_SUBSTREAMS] = { }; +#define GGML_CUDA_MAX_DEVICES 16 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; -static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; -static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; - -static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr }; - -void ggml_init_cublas() { +static void ggml_init_cublas() { static bool initialized = false; if (!initialized) { - CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); - GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); + int device_count; + CUDA_CHECK(cudaGetDeviceCount(&device_count)); int64_t total_vram = 0; - fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count); - for (int id = 0; id < g_device_count; ++id) { + fprintf(stderr, "%s: found %d CUDA 
devices:\n", __func__, device_count); + for (int id = 0; id < device_count; ++id) { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); - fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor); - - g_tensor_split[id] = total_vram; + fprintf(stderr, " Device %d: %s (%.0f GB)\n", id, prop.name, prop.totalGlobalMem / 1024.0 / 1024.0 / 1024.0); total_vram += prop.totalGlobalMem; - g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; } - for (int id = 0; id < g_device_count; ++id) { - g_tensor_split[id] /= total_vram; + + // create main stream and event + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream_main, cudaStreamNonBlocking)); + CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent_main, cudaEventDisableTiming)); + + // create secondary streams and events + for (int i = 0; i < GGML_CUDA_MAX_SUBSTREAMS; ++i) { + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking)); + CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming)); } - for (int id = 0; id < g_device_count; ++id) { - CUDA_CHECK(cudaSetDevice(id)); - - // create main stream - CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking)); - - // create cublas handle - CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); - CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH)); - } + // create cublas handle + CUBLAS_CHECK(cublasCreate(&g_cublas_handle)); + //CUBLAS_CHECK(cublasSetMathMode(g_cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); // configure logging to stdout - // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + //CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); initialized = true; } } -void ggml_cuda_set_tensor_split(const float * tensor_split) { - bool all_zero = true; - for (int i = 0; i < g_device_count; ++i) { - if (tensor_split[i] != 0.0f) { - all_zero = false; - break; - } - } - if (all_zero) { - return; - } - float split_sum = 0.0f; - for (int i = 0; i < g_device_count; ++i) { - g_tensor_split[i] = split_sum; - split_sum += tensor_split[i]; - } - for (int i = 0; i < g_device_count; ++i) { - g_tensor_split[i] /= split_sum; - } -} - void * ggml_cuda_host_malloc(size_t size) { if (getenv("GGML_CUDA_NO_PINNED") != nullptr) { return nullptr; @@ -2555,23 +585,396 @@ void ggml_cuda_host_free(void * ptr) { CUDA_CHECK(cudaFreeHost(ptr)); } -static cudaError_t ggml_cuda_cpy_tensor_2d( - void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { +template +static void ggml_cuda_op_add( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + void * src0_d, void * src1_d, void * dst_d, + int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t stream) { - cudaMemcpyKind kind; - char * src_ptr; - if (src->backend == GGML_BACKEND_CPU) { - kind = cudaMemcpyHostToDevice; - src_ptr = (char *) src->data; - } else if (src->backend == GGML_BACKEND_GPU) { - kind = cudaMemcpyDeviceToDevice; - struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; - int id; - CUDA_CHECK(cudaGetDevice(&id)); - src_ptr = (char *) extra->data_device[id]; - } else { - GGML_ASSERT(false); + const int64_t ne0 = src0->ne[0]; + const int64_t i01_diff = i01_high - i01_low; + + // compute + add_cuda((src0_t *)src0_d, (src1_t *) src1_d, (dst_t *) dst_d, ne0*i01_diff, stream); + CUDA_CHECK(cudaGetLastError()); + + UNUSED(src1); + UNUSED(dst); + UNUSED(i02); + UNUSED(i1); +} + 
+template +static void ggml_cuda_op_mul( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + void * src0_d, void * src1_d, void * dst_d, + int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + for (int64_t i01 = i01_low; i01 < i01_high; i01++) { + const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0 + + src0_t * src0_d_i01 = (src0_t *) src0_d + i01*ne00; + src1_t * src1_d_i01 = (src1_t *) src1_d + i11*ne10; + dst_t * dst_d_i01 = (dst_t *) dst_d + i01*ne00; + + // compute + mul_cuda(src0_d_i01, src1_d_i01, dst_d_i01, ne00, ne10, stream); + CUDA_CHECK(cudaGetLastError()); } + + UNUSED(dst); + UNUSED(i02); +} + +template +static void ggml_cuda_op_silu( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + void * src0_d, void * src1_d, void * dst_d, + int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t i01_diff = i01_high - i01_low; + + // compute + silu_cuda((src0_t *)src0_d, (dst_t *)dst_d, ne00*i01_diff, stream); + CUDA_CHECK(cudaGetLastError()); + + UNUSED(src1); + UNUSED(src1_d); + UNUSED(dst); + UNUSED(i02); + UNUSED(i1); +} + +template +static void ggml_cuda_op_rms_norm( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + void * src0_d, void * src1_d, void * dst_d, + int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t i01_diff = i01_high - i01_low; + + // compute + rms_norm_cuda((src0_t *)src0_d, (dst_t *)dst_d, ne00, i01_diff, stream); + CUDA_CHECK(cudaGetLastError()); + + UNUSED(src1); + UNUSED(src1_d); + UNUSED(dst); + UNUSED(i02); + UNUSED(i1); +} + +template +static void ggml_cuda_op_dequantize_mul_mat_vec( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + void * src0_d, void * src1_d, void * dst_d, + int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = i01_high - i01_low; + +#ifdef GGML_CUDA_FORCE_DMMV + const bool use_mul_mat_vec_q = false; +#else + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + const bool mul_mat_vec_q_implemented = src0->type == GGML_TYPE_Q4_0 || + src0->type == GGML_TYPE_Q4_1 || + src0->type == GGML_TYPE_Q5_0 || + src0->type == GGML_TYPE_Q5_1 || + src0->type == GGML_TYPE_Q8_0; + + // The integer intrinsics used in mul_mat_vec_q are available with compute capability 6. + // However, they have bad performance with Pascal cards. + // Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q. 
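+    // NOTE (assumption: g_compute_capabilities[] keeps the previous 100*major + 10*minor
+    // encoding): the >= 700 check below then selects mul_mat_vec_q only on Volta (7.0) and
+    // newer, while Pascal (6.x) and older cards take the dequantize_mul_mat_vec path even
+    // though the required integer intrinsics exist there.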
+ const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented; +#endif + + if (use_mul_mat_vec_q) { + size_t as; + void * src1_q8_1 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_1)/QK8_1, &as, stream); + quantize_row_q8_1_cuda((src1_t *)src1_d, src1_q8_1, ne00, stream); + + switch (src0->type) { + case GGML_TYPE_Q4_0: + mul_mat_vec_q4_0_q8_1_cuda(src0_d, src1_q8_1, (dst_t *)dst_d, ne00, nrows, stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_cuda(src0_d, src1_q8_1, (dst_t *)dst_d, ne00, nrows, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_cuda(src0_d, src1_q8_1, (dst_t *)dst_d, ne00, nrows, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_cuda(src0_d, src1_q8_1, (dst_t *)dst_d, ne00, nrows, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_cuda(src0_d, src1_q8_1, (dst_t *)dst_d, ne00, nrows, stream); + break; + default: + GGML_ASSERT(false); + break; + } + + ggml_cuda_pool_free(src1_q8_1, as, stream); + } + else { + switch (src0->type) { + case GGML_TYPE_Q4_0: + dequantize_mul_mat_vec_q4_0_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); + break; + case GGML_TYPE_Q4_1: + dequantize_mul_mat_vec_q4_1_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); + break; + case GGML_TYPE_Q5_0: + dequantize_mul_mat_vec_q5_0_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); + break; + case GGML_TYPE_Q5_1: + dequantize_mul_mat_vec_q5_1_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); + break; + case GGML_TYPE_Q8_0: + dequantize_mul_mat_vec_q8_0_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); + break; + /* + case GGML_TYPE_Q2_K: + dequantize_mul_mat_vec_q2_K_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q3_K: + dequantize_mul_mat_vec_q3_K_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q4_K: + dequantize_mul_mat_vec_q4_K_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q5_K: + dequantize_mul_mat_vec_q5_K_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, cudaStream_main); + break; + */ + case GGML_TYPE_Q6_K: + dequantize_mul_mat_vec_q6_K_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); + break; + case GGML_TYPE_F16: + convert_mul_mat_vec_f16_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); + break; + default: + GGML_ASSERT(false); + break; + } + } + CUDA_CHECK(cudaGetLastError()); + + UNUSED(src1); + UNUSED(dst); + UNUSED(i02); + UNUSED(i1); +} + +template +static void ggml_cuda_op_rope( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + void * src0_d, void * src1_d, void * dst_d, + int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t stream) { + + + const int64_t ne00 = src0->ne[0]; + const int64_t i01_diff = i01_high - i01_low; + + const int n_past = ((int32_t *) dst->params)[0]; + const int n_dims = ((int32_t *) dst->params)[1]; + const int mode = ((int32_t *) dst->params)[2]; + //const int n_ctx = ((int32_t *) dst->params)[3]; + GGML_ASSERT(mode == 0); + + const float theta_scale = powf(10000.0, -2.0f/n_dims); + const float p = ((mode & 1) == 0 ? 
n_past + i02 : i02); + + // compute + rope_cuda((src0_t *)src0_d, (dst_t *)dst_d, ne00, i01_diff, p, theta_scale, stream); + CUDA_CHECK(cudaGetLastError()); + + UNUSED(dst); + UNUSED(src1); + UNUSED(src1_d); + UNUSED(i1); +} + +template +static void ggml_cuda_op_diag_mask_inf( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + void * src0_d, void * src1_d, void * dst_d, + int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t i01_diff = i01_high - i01_low; + + const int n_past = ((int32_t *) dst->params)[0]; + + // compute + diag_mask_inf_cuda((src0_t *)src0_d, (dst_t *)dst_d, ne00, i01_diff, ne01, n_past, stream); + CUDA_CHECK(cudaGetLastError()); + + UNUSED(dst); + UNUSED(src1); + UNUSED(src1_d); + UNUSED(i02); + UNUSED(i1); +} + +template +static void ggml_cuda_op_soft_max( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + void * src0_d, void * src1_d, void * dst_d, + int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t i01_diff = i01_high - i01_low; + + // compute + soft_max_cuda((src0_t *)src0_d, (dst_t *)dst_d, ne00, i01_diff, stream); + CUDA_CHECK(cudaGetLastError()); + + UNUSED(src1); + UNUSED(src1_d); + UNUSED(dst); + UNUSED(i02); + UNUSED(i1); +} + +template +static void ggml_cuda_op_scale( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + void * src0_d, void * src1_d, void * dst_d, + int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t stream) { + + //const src1_t scale = ((src1_t *) src1->data)[0]; + + const int64_t ne00 = src0->ne[0]; + const int64_t i01_diff = i01_high - i01_low; + + // compute + scale_cuda((src0_t *)src0_d, (dst_t *)dst_d, (src1_t *)src1_d, ne00*i01_diff, stream); + CUDA_CHECK(cudaGetLastError()); + + UNUSED(src1); + UNUSED(src1_d); + UNUSED(dst); + UNUSED(i02); + UNUSED(i1); +} + +template +static void ggml_cuda_op_get_rows( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + void * src0_d, void * src1_d, void * dst_d, + int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int ncols = src0->ne[0]; + const int nrows = ggml_nelements(src1); + + switch (src0->type) { + case GGML_TYPE_F16: + get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_F32: + get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_Q4_0: + get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_Q4_1: + get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_Q5_0: + get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_Q5_1: + get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_Q8_0: + get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); + break; + + default: + GGML_ASSERT(false); + break; + } + CUDA_CHECK(cudaGetLastError()); + + UNUSED(i02); + UNUSED(i01_low); + UNUSED(i01_high); + UNUSED(i1); +} + 
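All of the ggml_cuda_op_* functions above share one signature, which is what makes the type-indexed dispatch table declared just below (ggml_cuda_op_dispatch_t) possible. The following is an editor's sketch of that idea only: the ggml_cuda_op_t definition is inferred from the shared parameter list in this patch, and N_TYPES, op_dispatch, and the registration comment are assumptions, not code taken from the patch.

#include <cstdint>
#include <cuda_runtime.h>

struct ggml_tensor;   // opaque here; defined in ggml.h

// assumed shape of ggml_cuda_op_t, mirroring the parameter list shared by the
// ggml_cuda_op_* templates above
typedef void (*ggml_cuda_op_t)(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
    void * src0_d, void * src1_d, void * dst_d,
    int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t stream);

constexpr int N_TYPES = 19;   // placeholder standing in for GGML_TYPE_COUNT

// one slot per (src0 type, src1 type or "none", dst type), mirroring
// d[GGML_TYPE_COUNT][GGML_TYPE_COUNT+1][GGML_TYPE_COUNT] in the struct below
struct op_dispatch {
    ggml_cuda_op_t d[N_TYPES][N_TYPES + 1][N_TYPES] = {};
};

// hypothetical registration of an all-float add:
//   dispatch.d[GGML_TYPE_F32][GGML_TYPE_F32][GGML_TYPE_F32] = ggml_cuda_op_add<float, float, float>;

int main() {
    op_dispatch dispatch;
    // all slots start out empty; a real table would be filled per supported type combination
    return dispatch.d[0][0][0] == nullptr ? 0 : 1;
}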
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_cuda_buffer {
+    const char * name;
+
+    void * data;
+    size_t size;
+    void * device;
+};
+
+struct ggml_cuda_context {
+    std::vector<ggml_cuda_buffer> buffers;
+};
+
+ggml_cuda_context * ggml_cuda_init() {
+    ggml_init_cublas();
+
+    ggml_cuda_context * ctx = new ggml_cuda_context;
+
+    return ctx;
+}
+
+void ggml_cuda_free(ggml_cuda_context * ctx) {
+    for (size_t n = 0; n < ctx->buffers.size(); ++n) {
+        if (ctx->buffers[n].device != nullptr) {
+            CUDA_CHECK(cudaFree(ctx->buffers[n].device));
+        }
+    }
+
+    // this will free the global memory pool for all contexts
+    ggml_cuda_pool_free_all();
+
+    delete ctx;
+}
+
+static void * ggml_cuda_get_buffer(ggml_cuda_context * ctx, ggml_tensor * t) {
+    return t->data;
+
+    UNUSED(ctx);
+}
+
+static cudaError_t ggml_cuda_cpy_tensor_2d(ggml_cuda_context * ctx,
+    void * dst, ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
+
+    cudaMemcpyKind kind = cudaMemcpyDeviceToDevice;
+    const char * src_ptr = (const char *) ggml_cuda_get_buffer(ctx, src);
     char * dst_ptr = (char *) dst;
 
     const int64_t ne0 = src->ne[0];
@@ -2584,6 +987,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     const int64_t bs = ggml_blck_size(type);
     int64_t i1_diff = i1_high - i1_low;
 
+    GGML_ASSERT(i1_low == 0);
     const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
     if (nb0 == ts && nb1 == ts*ne0/bs) {
         return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
@@ -2601,450 +1005,52 @@
     }
 }
 
-inline void ggml_cuda_op_add(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+static const ggml_type GGML_TYPE_NONE = GGML_TYPE_COUNT;
 
-    GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+struct ggml_cuda_op_dispatch_t {
+    ggml_cuda_op_t d[GGML_TYPE_COUNT][GGML_TYPE_COUNT+1][GGML_TYPE_COUNT] = { nullptr };
+};
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+template