From 569550df20c1ede59ff195a6b6e900957ad84d16 Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Fri, 29 Sep 2023 07:15:57 -0400 Subject: [PATCH 01/12] readme : add link to grammars app (#3388) * Add link to grammars app per @ggernagov suggestion Adding a sentence in the Grammars section of README to point to grammar app, per https://github.com/ggerganov/llama.cpp/discussions/2494#discussioncomment-7138211 * Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 9675ce1e7..8cdfb04e0 100644 --- a/README.md +++ b/README.md @@ -662,6 +662,8 @@ PROMPT_TEMPLATE=./prompts/chat-with-bob.txt PROMPT_CACHE_FILE=bob.prompt.bin \ The `grammars/` folder contains a handful of sample grammars. To write your own, check out the [GBNF Guide](./grammars/README.md). +For authoring more complex JSON grammars, you can also check out https://grammar.intrinsiclabs.ai/, a browser app that lets you write TypeScript interfaces which it compiles to GBNF grammars that you can save for local use. Note that the app is built and maintained by members of the community, please file any issues or FRs on [its repo](http://github.com/intrinsiclabsai/gbnfgen) and not this one. + ### Instruction mode with Alpaca 1. First, download the `ggml` Alpaca model into the `./models` folder From 0a4a4a098261ddd26480371eaccfe90d1bf6488a Mon Sep 17 00:00:00 2001 From: BarfingLemurs <128182951+BarfingLemurs@users.noreply.github.com> Date: Fri, 29 Sep 2023 08:50:35 -0400 Subject: [PATCH 02/12] readme : update hot topics + model links (#3399) --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8cdfb04e0..75b6075f2 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ ### Hot topics -- Parallel decoding + continuous batching support incoming: [#3228](https://github.com/ggerganov/llama.cpp/pull/3228) \ +- Parallel decoding + continuous batching support added: [#3228](https://github.com/ggerganov/llama.cpp/pull/3228) \ **Devs should become familiar with the new API** - Local Falcon 180B inference on Mac Studio @@ -92,7 +92,8 @@ as the main playground for developing new features for the [ggml](https://github - [X] [WizardLM](https://github.com/nlpxucan/WizardLM) - [X] [Baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) and its derivations (such as [baichuan-7b-sft](https://huggingface.co/hiyouga/baichuan-7b-sft)) - [X] [Aquila-7B](https://huggingface.co/BAAI/Aquila-7B) / [AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) -- [X] Mistral AI v0.1 +- [X] [Starcoder models](https://github.com/ggerganov/llama.cpp/pull/3187) +- [X] [Mistral AI v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) **Bindings:** From 2777a84be429401a2b7d33c2b6a4ada1f0776f1b Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Fri, 29 Sep 2023 09:48:45 -0400 Subject: [PATCH 03/12] llama : quantize up to 31% faster on Linux and Windows with mmap (#3206) * llama : enable mmap in quantize on Linux -> 31% faster * also enable mmap on Windows --------- Co-authored-by: Georgi Gerganov --- llama.cpp | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 666acc212..bff17135b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6027,7 +6027,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s nthread = std::thread::hardware_concurrency(); } - llama_model_loader ml(fname_inp, /*use_mmap*/ false); + // mmap consistently increases speed Linux, and also increases speed on Windows with + // hot cache. It may cause a slowdown on macOS, possibly related to free memory. +#if defined(__linux__) || defined(_WIN32) + constexpr bool use_mmap = true; +#else + constexpr bool use_mmap = false; +#endif + + llama_model_loader ml(fname_inp, use_mmap); + if (ml.use_mmap) { + ml.mapping.reset(new llama_mmap(&ml.file, /* prefetch */ 0, ggml_is_numa())); + } llama_model model; llm_load_arch(ml, model); @@ -6105,10 +6116,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const std::string name = ggml_get_name(tensor); - if (read_data.size() < ggml_nbytes(tensor)) { - read_data.resize(ggml_nbytes(tensor)); + if (!ml.use_mmap) { + if (read_data.size() < ggml_nbytes(tensor)) { + read_data.resize(ggml_nbytes(tensor)); + } + tensor->data = read_data.data(); } - tensor->data = read_data.data(); ml.load_data_for(tensor); LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", From bc34dd4f5b5a7c10ae3ed85a265ce6f2ed2fab79 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 29 Sep 2023 19:05:18 +0300 Subject: [PATCH 04/12] train : fix KQ_pos allocation (#3392) * train : fix KQ_pos allocation * make sure KQ_pos is not reallocated in finetune --------- Co-authored-by: xaedes --- examples/finetune/finetune.cpp | 5 ++++- examples/train-text-from-scratch/train-text-from-scratch.cpp | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index b61165fb7..8ca1874da 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -626,7 +626,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( // KQ_pos - contains the positions struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); - { + ggml_allocr_alloc(alloc, KQ_pos); + if (!ggml_allocr_is_measure(alloc)) { int * data = (int *) KQ_pos->data; for (int i = 0; i < N; ++i) { data[i] = n_past + i; @@ -786,6 +787,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one)); GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL); ggml_allocr_alloc(alloc, t36->grad); + // KQ_pos + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one)); // make sure base model tensors data cannot be used in viewable operations ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one)); diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 5043f32d0..be693b3ac 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -334,7 +334,8 @@ static struct ggml_tensor * llama_build_train_graphs( // KQ_pos - contains the positions struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); - { + ggml_allocr_alloc(alloc, KQ_pos); + if (!ggml_allocr_is_measure(alloc)) { int * data = (int *) KQ_pos->data; for (int i = 0; i < N; ++i) { data[i] = n_past + i; From 40e07a60f9ce06e79f3ccd4c903eba300fb31b5e Mon Sep 17 00:00:00 2001 From: slaren Date: Fri, 29 Sep 2023 18:42:32 +0200 Subject: [PATCH 05/12] llama.cpp : add documentation about rope_freq_base and scale values (#3401) * llama.cpp : add documentation about rope_freq_base and scale values * add notice to hot topics --- README.md | 1 + llama.h | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 75b6075f2..ec7b58943 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ ### Hot topics +- ‼️ Breaking change: `rope_freq_base` and `rope_freq_scale` must be set to zero to use the model default values: [#3401](https://github.com/ggerganov/llama.cpp/pull/3401) - Parallel decoding + continuous batching support added: [#3228](https://github.com/ggerganov/llama.cpp/pull/3228) \ **Devs should become familiar with the new API** - Local Falcon 180B inference on Mac Studio diff --git a/llama.h b/llama.h index 96ff1f09c..fde4d6eca 100644 --- a/llama.h +++ b/llama.h @@ -167,18 +167,18 @@ extern "C" { struct llama_context_params { uint32_t seed; // RNG seed, -1 for random - uint32_t n_ctx; // text context - uint32_t n_batch; // prompt processing batch size + uint32_t n_ctx; // text context, 0 = from model + uint32_t n_batch; // prompt processing maximum batch size uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing // ref: https://github.com/ggerganov/llama.cpp/pull/2054 - float rope_freq_base; // RoPE base frequency - float rope_freq_scale; // RoPE frequency scaling factor + float rope_freq_base; // RoPE base frequency, 0 = from model + float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model // Keep the booleans together to avoid misalignment during copy-by-value. bool mul_mat_q; // if true, use experimental mul_mat_q kernels - bool f16_kv; // use fp16 for KV cache + bool f16_kv; // use fp16 for KV cache, fp32 otherwise bool logits_all; // the llama_eval() call computes all logits, not just the last one bool embedding; // embedding mode only }; From f5ef5cfb18148131fcf45bdd2331f0db5ab7c3d0 Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 30 Sep 2023 18:12:57 +0200 Subject: [PATCH 06/12] ggml-cuda : perform cublas mat mul of quantized types as f16 (#3412) * ggml-cuda : perform cublas matrix multiplication of quantized types as fp16 * rename CC_TURING to CC_VOLTA * disable fp16 mat mul completely with multi GPU --- ggml-cuda.cu | 194 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 122 insertions(+), 72 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 86d1fe203..989c419cd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -80,9 +80,9 @@ #include "ggml.h" #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products -#define CC_TURING 700 +#define CC_VOLTA 700 #define CC_OFFSET_AMD 1000000 -#define CC_RDNA2 CC_OFFSET_AMD + 1030 +#define CC_RDNA2 (CC_OFFSET_AMD + 1030) #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 @@ -715,7 +715,8 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in //================================== k-quants -static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) { +template +static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int i = blockIdx.x; const block_q2_K * x = (const block_q2_K *) vx; @@ -727,7 +728,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float const int is = 8*n + l/16; const uint8_t q = x[i].qs[32*n + l]; - float * y = yy + i*QK_K + 128*n; + dst_t * y = yy + i*QK_K + 128*n; float dall = __low2half(x[i].dm); float dmin = __high2half(x[i].dm); @@ -739,7 +740,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float const int is = tid/16; // 0 or 1 const int il = tid%16; // 0...15 const uint8_t q = x[i].qs[il] >> (2*is); - float * y = yy + i*QK_K + 16*is + il; + dst_t * y = yy + i*QK_K + 16*is + il; float dall = __low2half(x[i].dm); float dmin = __high2half(x[i].dm); y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); @@ -748,7 +749,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float } -static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) { +template +static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int i = blockIdx.x; const block_q3_K * x = (const block_q3_K *) vx; @@ -772,7 +774,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float float d_all = x[i].d; float dl = d_all * (us - 32); - float * y = yy + i*QK_K + 128*n + 32*j; + dst_t * y = yy + i*QK_K + 128*n + 32*j; const uint8_t * q = x[i].qs + 32*n; const uint8_t * hm = x[i].hmask; @@ -784,7 +786,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float const int im = il/8; // 0...1 const int in = il%8; // 0...7 - float * y = yy + i*QK_K + 16*is + il; + dst_t * y = yy + i*QK_K + 16*is + il; const uint8_t q = x[i].qs[il] >> (2*is); const uint8_t h = x[i].hmask[in] >> (2*is + im); @@ -812,7 +814,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t } #endif -static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) { +template +static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q4_K * x = (const block_q4_K *) vx; const int i = blockIdx.x; @@ -825,7 +828,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float const int is = 2*il; const int n = 4; - float * y = yy + i*QK_K + 64*il + n*ir; + dst_t * y = yy + i*QK_K + 64*il + n*ir; const float dall = __low2half(x[i].dm); const float dmin = __high2half(x[i].dm); @@ -844,7 +847,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float #else const int tid = threadIdx.x; const uint8_t * q = x[i].qs; - float * y = yy + i*QK_K; + dst_t * y = yy + i*QK_K; const float d = (float)x[i].dm[0]; const float m = (float)x[i].dm[1]; y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); @@ -852,7 +855,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float #endif } -static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) { +template +static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q5_K * x = (const block_q5_K *) vx; const int i = blockIdx.x; @@ -864,7 +868,7 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float const int ir = tid%16; // ir is in 0...15 const int is = 2*il; // is is in 0...6 - float * y = yy + i*QK_K + 64*il + 2*ir; + dst_t * y = yy + i*QK_K + 64*il + 2*ir; const float dall = __low2half(x[i].dm); const float dmin = __high2half(x[i].dm); @@ -892,13 +896,14 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float const int is = tid/16; // 0 or 1 const uint8_t h = x[i].qh[in] >> im; const float d = x[i].d; - float * y = yy + i*QK_K + tid; + dst_t * y = yy + i*QK_K + tid; y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16)); #endif } -static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) { +template +static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q6_K * x = (const block_q6_K *) vx; const int i = blockIdx.x; @@ -910,7 +915,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float const int il = tid - 32*ip; // 0...32 const int is = 8*ip + il/16; - float * y = yy + i*QK_K + 128*ip + il; + dst_t * y = yy + i*QK_K + 128*ip + il; const float d = x[i].d; @@ -929,7 +934,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float const int ip = tid/16; // 0 or 1 const int il = tid - 16*ip; // 0...15 - float * y = yy + i*QK_K + 16*ip + il; + dst_t * y = yy + i*QK_K + 16*ip + il; const float d = x[i].d; @@ -3548,7 +3553,7 @@ template static __global__ void load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q4_0_AMPERE; const int mmq_y = MMQ_Y_Q4_0_AMPERE; const int nwarps = NWARPS_Q4_0_AMPERE; @@ -3568,7 +3573,7 @@ template static __global__ void #else (void) vec_dot_q4_0_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q4_1_RDNA2 64 @@ -3589,9 +3594,9 @@ template static __global__ void #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2) #endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_TURING +#elif __CUDA_ARCH__ < CC_VOLTA __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_TURING +#endif // __CUDA_ARCH__ < CC_VOLTA mul_mat_q4_1( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { @@ -3611,7 +3616,7 @@ template static __global__ void load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q4_1_AMPERE; const int mmq_y = MMQ_Y_Q4_1_AMPERE; const int nwarps = NWARPS_Q4_1_AMPERE; @@ -3631,7 +3636,7 @@ template static __global__ void #else (void) vec_dot_q4_1_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q5_0_RDNA2 64 @@ -3672,7 +3677,7 @@ template static __global__ void load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q5_0_AMPERE; const int mmq_y = MMQ_Y_Q5_0_AMPERE; const int nwarps = NWARPS_Q5_0_AMPERE; @@ -3692,7 +3697,7 @@ template static __global__ void #else (void) vec_dot_q5_0_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q5_1_RDNA2 64 @@ -3733,7 +3738,7 @@ mul_mat_q5_1( load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q5_1_AMPERE; const int mmq_y = MMQ_Y_Q5_1_AMPERE; const int nwarps = NWARPS_Q5_1_AMPERE; @@ -3753,7 +3758,7 @@ mul_mat_q5_1( #else (void) vec_dot_q5_1_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q8_0_RDNA2 64 @@ -3794,7 +3799,7 @@ template static __global__ void load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q8_0_AMPERE; const int mmq_y = MMQ_Y_Q8_0_AMPERE; const int nwarps = NWARPS_Q8_0_AMPERE; @@ -3814,7 +3819,7 @@ template static __global__ void #else (void) vec_dot_q8_0_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q2_K_RDNA2 64 @@ -3855,7 +3860,7 @@ mul_mat_q2_K( load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q2_K_AMPERE; const int mmq_y = MMQ_Y_Q2_K_AMPERE; const int nwarps = NWARPS_Q2_K_AMPERE; @@ -3875,7 +3880,7 @@ mul_mat_q2_K( #else (void) vec_dot_q2_K_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q3_K_RDNA2 128 @@ -3896,9 +3901,9 @@ template static __global__ void #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2) #endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_TURING +#elif __CUDA_ARCH__ < CC_VOLTA __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_TURING +#endif // __CUDA_ARCH__ < CC_VOLTA mul_mat_q3_K( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { @@ -3918,7 +3923,7 @@ template static __global__ void load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q3_K_AMPERE; const int mmq_y = MMQ_Y_Q3_K_AMPERE; const int nwarps = NWARPS_Q3_K_AMPERE; @@ -3938,7 +3943,7 @@ template static __global__ void #else (void) vec_dot_q3_K_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q4_K_RDNA2 64 @@ -3959,9 +3964,9 @@ template static __global__ void #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2) #endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_TURING +#elif __CUDA_ARCH__ < CC_VOLTA __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_TURING +#endif // __CUDA_ARCH__ < CC_VOLTA mul_mat_q4_K( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { @@ -3981,7 +3986,7 @@ template static __global__ void load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q4_K_AMPERE; const int mmq_y = MMQ_Y_Q4_K_AMPERE; const int nwarps = NWARPS_Q4_K_AMPERE; @@ -4001,7 +4006,7 @@ template static __global__ void #else (void) vec_dot_q4_K_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q5_K_RDNA2 64 @@ -4042,7 +4047,7 @@ mul_mat_q5_K( load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q5_K_AMPERE; const int mmq_y = MMQ_Y_Q5_K_AMPERE; const int nwarps = NWARPS_Q5_K_AMPERE; @@ -4062,7 +4067,7 @@ mul_mat_q5_K( #else (void) vec_dot_q5_K_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } #define MMQ_X_Q6_K_RDNA2 64 @@ -4083,9 +4088,9 @@ template static __global__ void #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2) #endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_TURING +#elif __CUDA_ARCH__ < CC_VOLTA __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_TURING +#endif // __CUDA_ARCH__ < CC_VOLTA mul_mat_q6_K( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { @@ -4105,7 +4110,7 @@ template static __global__ void load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#elif __CUDA_ARCH__ >= CC_TURING +#elif __CUDA_ARCH__ >= CC_VOLTA const int mmq_x = MMQ_X_Q6_K_AMPERE; const int mmq_y = MMQ_Y_Q6_K_AMPERE; const int nwarps = NWARPS_Q6_K_AMPERE; @@ -4125,7 +4130,7 @@ template static __global__ void #else (void) vec_dot_q6_K_q8_1_mul_mat; assert(false); -#endif // __CUDA_ARCH__ >= CC_TURING +#endif // __CUDA_ARCH__ >= CC_VOLTA } template @@ -4604,32 +4609,38 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con quantize_q8_1<<>>(x, vy, kx, kx_padded); } -static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q2_K<<>>(vx, y); @@ -4638,7 +4649,8 @@ static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cu #endif } -static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q3_K<<>>(vx, y); @@ -4647,12 +4659,14 @@ static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cu #endif } -static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q4_K<<>>(vx, y); } -static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); @@ -4661,7 +4675,8 @@ static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cu #endif } -static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +template +static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q6_K<<>>(vx, y); @@ -4868,6 +4883,26 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_row_q4_0_cuda; + case GGML_TYPE_Q4_1: + return dequantize_row_q4_1_cuda; + case GGML_TYPE_Q5_0: + return dequantize_row_q5_0_cuda; + case GGML_TYPE_Q5_1: + return dequantize_row_q5_1_cuda; + case GGML_TYPE_Q8_0: + return dequantize_row_q8_0_cuda; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; case GGML_TYPE_F32: return convert_fp32_to_fp16_cuda; default: @@ -4921,7 +4956,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( mmq_x = MMQ_X_Q4_0_RDNA1; mmq_y = MMQ_Y_Q4_0_RDNA1; nwarps = NWARPS_Q4_0_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q4_0_AMPERE; mmq_y = MMQ_Y_Q4_0_AMPERE; nwarps = NWARPS_Q4_0_AMPERE; @@ -4966,7 +5001,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( mmq_x = MMQ_X_Q4_1_RDNA1; mmq_y = MMQ_Y_Q4_1_RDNA1; nwarps = NWARPS_Q4_1_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q4_1_AMPERE; mmq_y = MMQ_Y_Q4_1_AMPERE; nwarps = NWARPS_Q4_1_AMPERE; @@ -5011,7 +5046,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( mmq_x = MMQ_X_Q5_0_RDNA1; mmq_y = MMQ_Y_Q5_0_RDNA1; nwarps = NWARPS_Q5_0_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q5_0_AMPERE; mmq_y = MMQ_Y_Q5_0_AMPERE; nwarps = NWARPS_Q5_0_AMPERE; @@ -5056,7 +5091,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( mmq_x = MMQ_X_Q5_1_RDNA1; mmq_y = MMQ_Y_Q5_1_RDNA1; nwarps = NWARPS_Q5_1_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q5_1_AMPERE; mmq_y = MMQ_Y_Q5_1_AMPERE; nwarps = NWARPS_Q5_1_AMPERE; @@ -5101,7 +5136,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( mmq_x = MMQ_X_Q8_0_RDNA1; mmq_y = MMQ_Y_Q8_0_RDNA1; nwarps = NWARPS_Q8_0_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q8_0_AMPERE; mmq_y = MMQ_Y_Q8_0_AMPERE; nwarps = NWARPS_Q8_0_AMPERE; @@ -5146,7 +5181,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( mmq_x = MMQ_X_Q2_K_RDNA1; mmq_y = MMQ_Y_Q2_K_RDNA1; nwarps = NWARPS_Q2_K_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q2_K_AMPERE; mmq_y = MMQ_Y_Q2_K_AMPERE; nwarps = NWARPS_Q2_K_AMPERE; @@ -5193,7 +5228,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( mmq_x = MMQ_X_Q3_K_RDNA1; mmq_y = MMQ_Y_Q3_K_RDNA1; nwarps = NWARPS_Q3_K_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q3_K_AMPERE; mmq_y = MMQ_Y_Q3_K_AMPERE; nwarps = NWARPS_Q3_K_AMPERE; @@ -5239,7 +5274,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( mmq_x = MMQ_X_Q4_K_RDNA1; mmq_y = MMQ_Y_Q4_K_RDNA1; nwarps = NWARPS_Q4_K_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q4_K_AMPERE; mmq_y = MMQ_Y_Q4_K_AMPERE; nwarps = NWARPS_Q4_K_AMPERE; @@ -5284,7 +5319,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( mmq_x = MMQ_X_Q5_K_RDNA1; mmq_y = MMQ_Y_Q5_K_RDNA1; nwarps = NWARPS_Q5_K_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q5_K_AMPERE; mmq_y = MMQ_Y_Q5_K_AMPERE; nwarps = NWARPS_Q5_K_AMPERE; @@ -5329,7 +5364,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda( mmq_x = MMQ_X_Q6_K_RDNA1; mmq_y = MMQ_Y_Q6_K_RDNA1; nwarps = NWARPS_Q6_K_RDNA1; - } else if (compute_capability >= CC_TURING) { + } else if (compute_capability >= CC_VOLTA) { mmq_x = MMQ_X_Q6_K_AMPERE; mmq_y = MMQ_Y_Q6_K_AMPERE; nwarps = NWARPS_Q6_K_AMPERE; @@ -5907,7 +5942,7 @@ static int64_t get_row_rounding(ggml_type type) { switch(type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - return max_compute_capability >= CC_TURING ? 128 : 64; + return max_compute_capability >= CC_VOLTA ? 128 : 64; case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: @@ -5918,7 +5953,7 @@ static int64_t get_row_rounding(ggml_type type) { case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: - return max_compute_capability >= CC_TURING ? 128 : 64; + return max_compute_capability >= CC_VOLTA ? 128 : 64; case GGML_TYPE_Q6_K: return 64; default: @@ -6083,8 +6118,19 @@ inline void ggml_cuda_op_mul_mat_cublas( const int compute_capability = g_compute_capabilities[id]; - if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous(src0) && ldc == row_diff) { - // convert src1 to fp16, multiply as fp16, convert dst to fp32 + if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { + // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 + half * src0_as_f16 = nullptr; + size_t src0_as = 0; + if (src0->type != GGML_TYPE_F16) { + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + size_t ne = row_diff*ne00; + src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); + to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream); + } + const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16; + half * src1_as_f16 = nullptr; size_t src1_as = 0; if (src1->type != GGML_TYPE_F16) { @@ -6106,9 +6152,9 @@ inline void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK( cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, row_diff, src1_ncols, ne10, - &alpha_f16, src0_dd_i, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, - &beta_f16, dst_f16, CUDA_R_16F, ldc, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16, CUDA_R_16F, ldc, CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); @@ -6117,6 +6163,10 @@ inline void ggml_cuda_op_mul_mat_cublas( ggml_cuda_pool_free(dst_f16, dst_as); + if (src0_as != 0) { + ggml_cuda_pool_free(src0_as_f16, src0_as); + } + if (src1_as != 0) { ggml_cuda_pool_free(src1_as_f16, src1_as); } From c97f01c362ac102c6994edb80008f8608539553a Mon Sep 17 00:00:00 2001 From: vvhg1 <94630311+vvhg1@users.noreply.github.com> Date: Mon, 2 Oct 2023 09:42:02 +0200 Subject: [PATCH 07/12] infill : add new example + extend server API (#3296) * vvhg-code-infill (#1) * infill in separate example (#2) * reverted changes to main and added infill example * cleanup * naming improvement * make : add missing blank line * fix missing semicolon * brought infill up to current main code * cleanup --------- Co-authored-by: Cebtenzzre --- .gitignore | 1 + Makefile | 5 +- common/common.cpp | 2 + common/common.h | 1 + examples/infill/CMakeLists.txt | 8 + examples/infill/README.md | 41 ++ examples/infill/infill.cpp | 769 +++++++++++++++++++++++++++++++++ examples/server/README.md | 10 + examples/server/server.cpp | 206 +++++++++ llama.cpp | 20 + llama.h | 5 + 11 files changed, 1067 insertions(+), 1 deletion(-) create mode 100644 examples/infill/CMakeLists.txt create mode 100644 examples/infill/README.md create mode 100644 examples/infill/infill.cpp diff --git a/.gitignore b/.gitignore index f98132a22..a552139f1 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ models-mnt /embedding /gguf /gguf-llama-simple +/infill /libllama.so /llama-bench /main diff --git a/Makefile b/Makefile index 08b83ca7e..91198c555 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Define the default target now so that it is always the first target -BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative benchmark-matmult parallel finetune export-lora tests/test-c.o +BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o # Binaries only useful for tests TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama @@ -543,6 +543,9 @@ main: examples/main/main.cpp build-info.h ggml. @echo '==== Run ./main -h for help. ====' @echo +infill: examples/infill/infill.cpp build-info.h ggml.o llama.o common.o console.o grammar-parser.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) + simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) diff --git a/common/common.cpp b/common/common.cpp index ec181c6b3..4b233786a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -389,6 +389,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.interactive_first = true; } else if (arg == "-ins" || arg == "--instruct") { params.instruct = true; + } else if (arg == "--infill") { + params.infill = true; } else if (arg == "--multiline-input") { params.multiline_input = true; } else if (arg == "--simple-io") { diff --git a/common/common.h b/common/common.h index 0e2d3fa6c..e095c56e3 100644 --- a/common/common.h +++ b/common/common.h @@ -120,6 +120,7 @@ struct gpt_params { bool use_mlock = false; // use mlock to keep model in memory bool numa = false; // attempt optimizations that help on some NUMA systems bool verbose_prompt = false; // print prompt tokens before generation + bool infill = false; // use infill mode }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params); diff --git a/examples/infill/CMakeLists.txt b/examples/infill/CMakeLists.txt new file mode 100644 index 000000000..046f9b1e7 --- /dev/null +++ b/examples/infill/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET infill) +add_executable(${TARGET} infill.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +if(TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() diff --git a/examples/infill/README.md b/examples/infill/README.md new file mode 100644 index 000000000..8c97f719b --- /dev/null +++ b/examples/infill/README.md @@ -0,0 +1,41 @@ +# llama.cpp/example/infill + +This example shows how to use the infill mode with Code Llama models supporting infill mode. +Currently the 7B and 13B models support infill mode. + +Infill supports most of the options available in the main example. + +For further information have a look at the main README.md in llama.cpp/example/main/README.md + +## Common Options + +In this section, we cover the most commonly used options for running the `infill` program with the LLaMA models: + +- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`). +- `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses. +- `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text. +- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. + +## Input Prompts + +The `infill` program provides several ways to interact with the LLaMA models using input prompts: + +- `--in-prefix PROMPT_BEFORE_CURSOR`: Provide the prefix directly as a command-line option. +- `--in-suffix PROMPT_AFTER_CURSOR`: Provide the suffix directly as a command-line option. +- `--interactive-first`: Run the program in interactive mode and wait for input right away. (More on this below.) + +## Interaction + +The `infill` program offers a seamless way to interact with LLaMA models, allowing users to receive real-time infill suggestions. The interactive mode can be triggered using `--interactive`, and `--interactive-first` + +### Interaction Options + +- `-i, --interactive`: Run the program in interactive mode, allowing users to get real time code suggestions from model. +- `--interactive-first`: Run the program in interactive mode and immediately wait for user input before starting the text generation. +- `--color`: Enable colorized output to differentiate visually distinguishing between prompts, user input, and generated text. + +### Example + +```bash +./infill -t 10 -ngl 0 -m models/codellama-13b.Q5_K_S.gguf -c 4096 --temp 0.7 --repeat_penalty 1.1 -n 20 --in-prefix "def helloworld():\n print(\"hell" --in-suffix "\n print(\"goodbye world\")\n " +``` diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp new file mode 100644 index 000000000..9ec75ce42 --- /dev/null +++ b/examples/infill/infill.cpp @@ -0,0 +1,769 @@ +#include "common.h" + +#include "console.h" +#include "llama.h" +#include "build-info.h" +#include "grammar-parser.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) +#include +#include +#elif defined (_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include +#endif + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +static llama_context ** g_ctx; +static llama_model ** g_model; +static gpt_params * g_params; +static std::vector * g_input_tokens; +static std::ostringstream * g_output_ss; +static std::vector * g_output_tokens; +static bool is_interacting = false; + + +static void write_logfile( + const llama_context * ctx, const gpt_params & params, const llama_model * model, + const std::vector & input_tokens, const std::string & output, + const std::vector & output_tokens +) { + if (params.logdir.empty()) { + return; + } + + const std::string timestamp = get_sortable_timestamp(); + + const bool success = create_directory_with_parents(params.logdir); + if (!success) { + fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n", + __func__, params.logdir.c_str()); + return; + } + + const std::string logfile_path = params.logdir + timestamp + ".yml"; + FILE * logfile = fopen(logfile_path.c_str(), "w"); + + if (logfile == NULL) { + fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str()); + return; + } + + fprintf(logfile, "binary: infill\n"); + char model_desc[128]; + llama_model_desc(model, model_desc, sizeof(model_desc)); + dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc); + + fprintf(logfile, "\n"); + fprintf(logfile, "######################\n"); + fprintf(logfile, "# Generation Results #\n"); + fprintf(logfile, "######################\n"); + fprintf(logfile, "\n"); + + dump_string_yaml_multiline(logfile, "output", output.c_str()); + dump_vector_int_yaml(logfile, "output_tokens", output_tokens); + + llama_dump_timing_info_yaml(logfile, ctx); + fclose(logfile); +} + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) +static void sigint_handler(int signo) { + if (signo == SIGINT) { + if (!is_interacting) { + is_interacting = true; + } else { + console::cleanup(); + printf("\n"); + llama_print_timings(*g_ctx); + write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); + _exit(130); + } + } +} +#endif + +int main(int argc, char ** argv) { + gpt_params params; + g_params = ¶ms; + + if (!gpt_params_parse(argc, argv, params)) { + return 1; + } + +#ifndef LOG_DISABLE_LOGS + log_set_target(log_filename_generator("infill", "log")); + LOG_TEE("Log start\n"); + log_dump_cmdline(argc, argv); +#endif // LOG_DISABLE_LOGS + + console::init(params.simple_io, params.use_color); + atexit([]() { console::cleanup(); }); + + if (params.logits_all) { + printf("\n************\n"); + printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__); + printf("************\n\n"); + + return 0; + } + + if (params.embedding) { + printf("\n************\n"); + printf("%s: please use the 'embedding' tool for embedding calculations\n", __func__); + printf("************\n\n"); + + return 0; + } + + if (params.n_ctx != 0 && params.n_ctx < 8) { + LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__); + params.n_ctx = 8; + } + if (params.instruct) { + printf("\n************\n"); + printf("%s: please use the 'main' tool for instruct mode\n", __func__); + printf("************\n\n"); + + return 0; + } + if (!params.antiprompt.empty()) { + printf("\n************\n"); + printf("%s: please use the 'main' tool for antiprompt mode\n", __func__); + printf("************\n\n"); + + return 0; + } + if (!params.interactive_first && (params.input_prefix.empty() && params.input_suffix.empty())) { + printf("\n************\n"); + printf("%s: please use '--interactive_first' or specify '--in_prefix' and/or '--in_suffix'\n", __func__); + printf("************\n\n"); + + return 0; + } + if (params.random_prompt) { + printf("\n************\n"); + printf("%s: please use the 'main' tool for random prompt mode\n", __func__); + printf("************\n\n"); + + return 0; + } + if (!params.path_prompt_cache.empty()) { + printf("\n************\n"); + printf("%s: infill does not support prompt caching\n", __func__); + printf("************\n\n"); + + return 0; + } + + if (params.rope_freq_base != 0.0) { + LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base); + } + + if (params.rope_freq_scale != 0.0) { + LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); + } + + LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); + LOG_TEE("%s: built with %s for %s\n", __func__, BUILD_COMPILER, BUILD_TARGET); + + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); + } + + LOG_TEE("%s: seed = %u\n", __func__, params.seed); + + std::mt19937 rng(params.seed); + + LOG("%s: llama backend init\n", __func__); + llama_backend_init(params.numa); + + llama_model * model; + llama_context * ctx; + llama_context * ctx_guidance = NULL; + g_model = &model; + g_ctx = &ctx; + + // load the model and apply lora adapter, if any + LOG("%s: load the model and apply lora adapter, if any\n", __func__); + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (params.cfg_scale > 1.f) { + struct llama_context_params lparams = llama_context_params_from_gpt_params(params); + ctx_guidance = llama_new_context_with_model(model, lparams); + } + + if (model == NULL) { + LOG_TEE("%s: error: unable to load model\n", __func__); + return 1; + } + + const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx = llama_n_ctx(ctx); + LOG("n_ctx: %d\n", n_ctx); + + if (n_ctx > n_ctx_train) { + LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n", + __func__, n_ctx_train, n_ctx); + } + + // print system information + { + LOG_TEE("\n"); + LOG_TEE("%s\n", get_system_info(params).c_str()); + } + const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM; + LOG("add_bos: %d\n", add_bos); + + std::vector embd_inp; + std::vector inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos); + std::vector inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos); + inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx)); + inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx)); + embd_inp = inp_pfx; + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + embd_inp.push_back(llama_token_middle(ctx)); + + LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix)); + LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix)); + LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp)); + + // Should not run without any tokens + if (embd_inp.empty()) { + embd_inp.push_back(llama_token_bos(ctx)); + LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp)); + } + + // Tokenize negative prompt + std::vector guidance_inp; + int guidance_offset = 0; + int original_prompt_len = 0; + if (ctx_guidance) { + LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt)); + + guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos); + LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp)); + + std::vector original_inp = ::llama_tokenize(ctx, params.prompt, add_bos); + LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp)); + + original_prompt_len = original_inp.size(); + guidance_offset = (int)guidance_inp.size() - original_prompt_len; + LOG("original_prompt_len: %s", log_tostr(original_prompt_len)); + LOG("guidance_offset: %s", log_tostr(guidance_offset)); + } + + if ((int) embd_inp.size() > n_ctx - 4) { + LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); + return 1; + } + + // number of tokens to keep when resetting context + if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size()) { + params.n_keep = (int)embd_inp.size(); + } + + LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx)); + LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx)); + + + // enable interactive mode if interactive start is specified + if (params.interactive_first) { + params.interactive = true; + } + + if (params.verbose_prompt) { + LOG_TEE("\n"); + LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + for (int i = 0; i < (int) embd_inp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str()); + } + + if (ctx_guidance) { + LOG_TEE("\n"); + LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str()); + LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); + for (int i = 0; i < (int) guidance_inp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str()); + } + } + + if (params.n_keep > 0) { + LOG_TEE("%s: static prompt based on n_keep: '", __func__); + for (int i = 0; i < params.n_keep; i++) { + LOG_TEE("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str()); + } + LOG_TEE("'\n"); + } + LOG_TEE("\n"); + } + + if (params.interactive) { +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = sigint_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif + + LOG_TEE("%s: interactive mode on.\n", __func__); + + if (params.input_prefix_bos) { + LOG_TEE("Input prefix with BOS\n"); + } + + if (!params.input_prefix.empty()) { + LOG_TEE("Input prefix: '%s'\n", params.input_prefix.c_str()); + } + + if (!params.input_suffix.empty()) { + LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); + } + } + LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", + params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); + LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + LOG_TEE("\n\n"); + + struct llama_grammar * grammar = NULL; + grammar_parser::parse_state parsed_grammar; + + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + return 1; + } + LOG_TEE("%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, parsed_grammar); + LOG_TEE("\n"); + + { + auto it = params.logit_bias.find(llama_token_eos(ctx)); + if (it != params.logit_bias.end() && it->second == -INFINITY) { + LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); + } + } + + std::vector grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + } + + // TODO: replace with ring-buffer + std::vector last_tokens(n_ctx); + std::fill(last_tokens.begin(), last_tokens.end(), 0); + LOG_TEE("\n##### Infill mode #####\n\n"); + if (params.infill) { + printf("\n************\n"); + printf("no need to specify '--infill', always running infill\n"); + printf("************\n\n"); + } + if (params.interactive) { + const char *control_message; + if (params.multiline_input) { + control_message = " - To return control to LLaMa, end your input with '\\'.\n" + " - To return control without starting a new line, end your input with '/'.\n"; + } else { + control_message = " - Press Return to return control to LLaMa.\n" + " - To return control without starting a new line, end your input with '/'.\n" + " - If you want to submit another line, end your input with '\\'.\n"; + } + LOG_TEE("== Running in interactive mode. ==\n"); +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) + LOG_TEE( " - Press Ctrl+C to interject at any time.\n"); +#endif + LOG_TEE( "%s\n", control_message); + + is_interacting = params.interactive_first; + } + + bool input_echo = true; + + int n_past = 0; + int n_remain = params.n_predict; + int n_consumed = 0; + int n_past_guidance = 0; + + std::vector input_tokens; g_input_tokens = &input_tokens; + std::vector output_tokens; g_output_tokens = &output_tokens; + std::ostringstream output_ss; g_output_ss = &output_ss; + + // the first thing we will do is to output the prompt, so set color accordingly + console::set_display(console::prompt); + + std::vector embd; + std::vector embd_guidance; + + const int n_vocab = llama_n_vocab(model); + + std::vector candidates; + candidates.reserve(n_vocab); + + while (n_remain != 0 || params.interactive) { + // predict + if (!embd.empty()) { + // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via + // --prompt or --file which uses the same value. + int max_embd_size = n_ctx - 4; + + // Ensure the input doesn't exceed the context size by truncating embd if necessary. + if ((int) embd.size() > max_embd_size) { + const int skipped_tokens = (int) embd.size() - max_embd_size; + embd.resize(max_embd_size); + + console::set_display(console::error); + printf("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); + console::set_display(console::reset); + fflush(stdout); + } + + // infinite text generation via context swapping + // if we run out of context: + // - take the n_keep first tokens from the original prompt (via n_past) + // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches + if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) { + if (params.n_predict == -2) { + LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); + break; + } + + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; + + LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + n_past, n_left, n_ctx, params.n_keep, n_discard); + + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + + n_past -= n_discard; + + if (ctx_guidance) { + n_past_guidance -= n_discard; + } + + LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + + LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); + + } + + // evaluate tokens in batches + // embd is typically prepared beforehand to fit within a batch, but not always + + if (ctx_guidance) { + int input_size = 0; + llama_token * input_buf = NULL; + + if (n_past_guidance < (int) guidance_inp.size()) { + // Guidance context should have the same data with these modifications: + // + // * Replace the initial prompt + // * Shift everything by guidance_offset + embd_guidance = guidance_inp; + if (embd.begin() + original_prompt_len < embd.end()) { + embd_guidance.insert( + embd_guidance.end(), + embd.begin() + original_prompt_len, + embd.end() + ); + } + + input_buf = embd_guidance.data(); + input_size = embd_guidance.size(); + + LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance)); + } else { + input_buf = embd.data(); + input_size = embd.size(); + } + + for (int i = 0; i < input_size; i += params.n_batch) { + int n_eval = std::min(input_size - i, params.n_batch); + if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) { + LOG_TEE("%s : failed to eval\n", __func__); + return 1; + } + + n_past_guidance += n_eval; + } + } + + for (int i = 0; i < (int) embd.size(); i += params.n_batch) { + int n_eval = (int) embd.size() - i; + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + + LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); + + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { + LOG_TEE("%s : failed to eval\n", __func__); + return 1; + } + + n_past += n_eval; + + LOG("n_past = %d\n", n_past); + } + + } + + embd.clear(); + embd_guidance.clear(); + + if ((int) embd_inp.size() <= n_consumed && !is_interacting) { + + const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates); + + last_tokens.erase(last_tokens.begin()); + last_tokens.push_back(id); + + LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens)); + + embd.push_back(id); + + // echo this to console + input_echo = true; + + // decrement remaining sampling budget + --n_remain; + + LOG("n_remain: %d\n", n_remain); + } else { + // some user input remains from prompt or interaction, forward it to processing + LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); + while ((int) embd_inp.size() > n_consumed) { + embd.push_back(embd_inp[n_consumed]); + last_tokens.erase(last_tokens.begin()); + last_tokens.push_back(embd_inp[n_consumed]); + ++n_consumed; + if ((int) embd.size() >= params.n_batch) { + break; + } + } + } + + // display text + if (input_echo) { + for (auto id : embd) { + const std::string token_str = llama_token_to_piece(ctx, id); + printf("%s", token_str.c_str()); + + if (embd.size() > 1) { + input_tokens.push_back(id); + } else { + output_tokens.push_back(id); + output_ss << token_str; + } + } + fflush(stdout); + } + // reset color to default if we there is no pending user input + if (input_echo && (int) embd_inp.size() == n_consumed) { + console::set_display(console::reset); + } + + // if not currently processing queued inputs; + if ((int) embd_inp.size() <= n_consumed) { + + // deal with eot token in infill mode + if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){ + if(is_interacting && !params.interactive_first) { + // print an eot token + printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str()); + } + fflush(stdout); + printf("\n"); + console::set_display(console::user_input); + std::string buffer; + std::string line; + bool another_line=true; + // set a new prefix via stdin + do { + another_line = console::readline(line, params.multiline_input); + buffer += line; + } while (another_line); + // check if we got an empty line, if so we use the old input + if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { + params.input_prefix = buffer; + } + buffer.clear(); + // set a new suffix via stdin + do { + another_line = console::readline(line, params.multiline_input); + buffer += line; + } while (another_line); + // check if we got an empty line + if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { + params.input_suffix = buffer; + } + buffer.clear(); + // done taking input, reset color + console::set_display(console::reset); + // tokenize new prefix and suffix + std::vector inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos); + std::vector inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos); + inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx)); + inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx)); + embd_inp = inp_pfx; + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + embd_inp.push_back(llama_token_middle(ctx)); + embd.clear(); + embd_guidance.clear(); + n_remain = params.n_predict; + n_past = 0; + n_consumed = 0; + // LOG_TEE("took new input\n"); + is_interacting = false; + } + // deal with end of text token in interactive mode + else if (last_tokens.back() == llama_token_eos(ctx)) { + LOG("found EOS token\n"); + + if (params.interactive) { + + is_interacting = true; + printf("\n"); + console::set_display(console::user_input); + fflush(stdout); + } + } + + if (n_past > 0 && is_interacting && !params.interactive) { + LOG("waiting for user input\n"); + + if (params.input_prefix_bos) { + LOG("adding input prefix BOS token\n"); + embd_inp.push_back(llama_token_bos(ctx)); + } + + std::string buffer; + if (!params.input_prefix.empty()) { + LOG("appending input prefix: '%s'\n", params.input_prefix.c_str()); + buffer += params.input_prefix; + printf("%s", buffer.c_str()); + } + + std::string line; + bool another_line = true; + do { + another_line = console::readline(line, params.multiline_input); + buffer += line; + } while (another_line); + + // done taking input, reset color + console::set_display(console::reset); + + // Add tokens to embd only if the input buffer is non-empty + // Entering a empty line lets the user pass control back + if (buffer.length() > 1) { + // append input suffix if any + if (!params.input_suffix.empty()) { + LOG("appending input suffix: '%s'\n", params.input_suffix.c_str()); + buffer += params.input_suffix; + printf("%s", params.input_suffix.c_str()); + } + + LOG("buffer: '%s'\n", buffer.c_str()); + + const size_t original_size = embd_inp.size(); + + const auto line_inp = ::llama_tokenize(ctx, buffer, false); + LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp)); + + embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + + for (size_t i = original_size; i < embd_inp.size(); ++i) { + const llama_token token = embd_inp[i]; + output_tokens.push_back(token); + output_ss << llama_token_to_piece(ctx, token); + } + + n_remain -= line_inp.size(); + LOG("n_remain: %d\n", n_remain); + } else { + LOG("empty line, passing control back\n"); + } + + input_echo = false; // do not echo this again + } + + if (n_past > 0) { + if (is_interacting) { + // reset grammar state if we're restarting generation + if (grammar != NULL) { + llama_grammar_free(grammar); + + std::vector grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), + parsed_grammar.symbol_ids.at("root")); + } + } + is_interacting = false; + } + } + + // end of text token + if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !params.interactive) { + break; + } + + // In interactive mode, respect the maximum number of tokens and drop back to user input when reached. + // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size). + if (params.interactive && n_remain <= 0 && params.n_predict >= 0) { + n_remain = params.n_predict; + is_interacting = true; + } + } + if (!params.interactive && n_remain <= 0) { + printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str()); + fflush(stdout); + } + + llama_print_timings(ctx); + write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); + + if (ctx_guidance) { llama_free(ctx_guidance); } + llama_free(ctx); + llama_free_model(model); + + if (grammar != NULL) { + llama_grammar_free(grammar); + } + llama_backend_free(); + +#ifndef LOG_DISABLE_LOGS + LOG_TEE("Log end\n"); +#endif // LOG_DISABLE_LOGS + + return 0; +} + diff --git a/examples/server/README.md b/examples/server/README.md index d409e8408..9ee62d06a 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -176,6 +176,16 @@ node index.js `content`: Set the text to process. + **POST** `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as stream. + + *Options:* + + `input_prefix`: Set the prefix of the code to infill. + + `input_suffix`: Set the suffix of the code to infill. + + It also accepts all the options of `/completion` except `stream` and `prompt`. + ## More examples ### Interactive mode diff --git a/examples/server/server.cpp b/examples/server/server.cpp index fe9a4255e..6dda5e36b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -342,6 +342,70 @@ struct llama_server_context return true; } + void loadInfill() + { + auto prefix_tokens = tokenize(params.input_prefix, true); // always add BOS + auto suffix_tokens = tokenize(params.input_suffix, true); // always add BOS + prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx)); + prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx)); + prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + prefix_tokens.push_back(llama_token_middle(ctx)); + auto prompt_tokens = prefix_tokens; + + num_prompt_tokens = prompt_tokens.size(); + + if (params.n_keep < 0) + { + params.n_keep = (int)num_prompt_tokens; + } + params.n_keep = std::min(params.n_ctx - 4, params.n_keep); + + // if input prompt is too big, truncate like normal + if (num_prompt_tokens >= (size_t)params.n_ctx) + { + printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens); + // todo we probably want to cut from both sides + const int n_left = (params.n_ctx - params.n_keep) / 2; + std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; + new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); + std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + + LOG_VERBOSE("input truncated", { + {"n_ctx", params.n_ctx}, + {"n_keep", params.n_keep}, + {"n_left", n_left}, + {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, + }); + + truncated = true; + prompt_tokens = new_tokens; + } + else + { + const size_t ps = num_prompt_tokens; + std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + } + + // compare the evaluated prompt with the new prompt + n_past = common_part(embd, prompt_tokens); + embd = prompt_tokens; + if (n_past == num_prompt_tokens) + { + // we have to evaluate at least 1 token to generate logits. + printf("we have to evaluate at least 1 token to generate logits\n"); + n_past--; + } + + LOG_VERBOSE("prompt ingested", { + {"n_past", n_past}, + {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)}, + {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, + }); + + has_next_token = true; + } void loadPrompt() { auto prompt_tokens = tokenize(prompt, true); // always add BOS @@ -1219,6 +1283,27 @@ static void parse_options_completion(const json &body, llama_server_context &lla LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); } +static void parse_options_infill(const json &body, llama_server_context &llama) +{ + if (body.count("input_prefix") != 0) + { + llama.params.input_prefix = body["input_prefix"]; + } + else + { + llama.params.input_prefix = ""; + } + if (body.count("input_suffix") != 0) + { + llama.params.input_suffix = body["input_suffix"]; + } + else + { + llama.params.input_suffix = ""; + } + parse_options_completion(body, llama); +} + static void log_server_request(const Request &req, const Response &res) { LOG_INFO("request", { @@ -1519,6 +1604,127 @@ int main(int argc, char **argv) res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } }); + svr.Post("/infill", [&llama](const Request &req, Response &res) + { + auto lock = llama.lock(); + + llama.rewind(); + + llama_reset_timings(llama.ctx); + + parse_options_infill(json::parse(req.body), llama); + + if (!llama.loadGrammar()) + { + res.status = 400; + return; + } + llama.loadInfill(); + llama.beginCompletion(); + const auto chunked_content_provider = [&](size_t, DataSink & sink) { + size_t sent_count = 0; + size_t sent_token_probs_index = 0; + + while (llama.has_next_token) { + const completion_token_output token_with_probs = llama.doCompletion(); + if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) { + continue; + } + const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok); + + size_t pos = std::min(sent_count, llama.generated_text.size()); + + const std::string str_test = llama.generated_text.substr(pos); + bool is_stop_full = false; + size_t stop_pos = + llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL); + if (stop_pos != std::string::npos) { + is_stop_full = true; + llama.generated_text.erase( + llama.generated_text.begin() + pos + stop_pos, + llama.generated_text.end()); + pos = std::min(sent_count, llama.generated_text.size()); + } else { + is_stop_full = false; + stop_pos = llama.findStoppingStrings(str_test, token_text.size(), + STOP_PARTIAL); + } + + if ( + stop_pos == std::string::npos || + // Send rest of the text if we are at the end of the generation + (!llama.has_next_token && !is_stop_full && stop_pos > 0) + ) { + const std::string to_send = llama.generated_text.substr(pos, std::string::npos); + + sent_count += to_send.size(); + + std::vector probs_output = {}; + + if (llama.params.n_probs > 0) { + const std::vector to_send_toks = llama_tokenize(llama.ctx, to_send, false); + size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); + size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); + if (probs_pos < probs_stop_pos) { + probs_output = std::vector(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); + } + sent_token_probs_index = probs_stop_pos; + } + + const json data = format_partial_response(llama, to_send, probs_output); + + const std::string str = + "data: " + + data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + + if (!sink.write(str.data(), str.size())) { + LOG_VERBOSE("stream closed", {}); + llama_print_timings(llama.ctx); + return false; + } + } + + if (!llama.has_next_token) { + // Generation is done, send extra information. + const json data = format_final_response( + llama, + "", + std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index) + ); + + const std::string str = + "data: " + + data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + + if (!sink.write(str.data(), str.size())) { + LOG_VERBOSE("stream closed", {}); + llama_print_timings(llama.ctx); + return false; + } + } + } + + llama_print_timings(llama.ctx); + sink.done(); + return true; + }; + const auto on_complete = [&](bool) { + llama.mutex.unlock(); + }; + lock.release(); + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + }); + svr.Get("/model.json", [&llama](const Request &, Response &res) { const json data = format_generation_settings(llama); diff --git a/llama.cpp b/llama.cpp index bff17135b..3a0b2c308 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1076,6 +1076,10 @@ struct llama_vocab { id special_pad_id = -1; id linefeed_id = 13; + id special_prefix_id = 32007; + id special_middle_id = 32009; + id special_suffix_id = 32008; + id special_eot_id = 32010; int find_bpe_rank(std::string token_left, std::string token_right) const { replace_all(token_left, " ", "\u0120"); @@ -7489,6 +7493,22 @@ llama_token llama_token_eos(const struct llama_context * ctx) { llama_token llama_token_nl(const struct llama_context * ctx) { return ctx->model.vocab.linefeed_id; } +llama_token llama_token_prefix(const struct llama_context * ctx) { + return ctx->model.vocab.special_prefix_id; +} + +llama_token llama_token_middle(const struct llama_context * ctx) { + return ctx->model.vocab.special_middle_id; +} + +llama_token llama_token_suffix(const struct llama_context * ctx) { + return ctx->model.vocab.special_suffix_id; +} + +llama_token llama_token_eot(const struct llama_context * ctx) { + return ctx->model.vocab.special_eot_id; +} + int llama_tokenize( const struct llama_model * model, diff --git a/llama.h b/llama.h index fde4d6eca..fd2158400 100644 --- a/llama.h +++ b/llama.h @@ -490,6 +490,11 @@ extern "C" { LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line + // codellama infill tokens + LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of infill prefix + LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of infill middle + LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of infill suffix + LLAMA_API llama_token llama_token_eot (const struct llama_context * ctx); // End of infill middle // // Tokenization From ea55295a745c084f588be20710f5a1a12abb1109 Mon Sep 17 00:00:00 2001 From: Kevin Ji <1146876+kevinji@users.noreply.github.com> Date: Mon, 2 Oct 2023 04:53:53 -0400 Subject: [PATCH 08/12] docker : ignore Git files (#3314) --- .dockerignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.dockerignore b/.dockerignore index c6ef6c86c..633bbc3a9 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,6 +1,9 @@ *.o *.a .cache/ +.git/ +.github/ +.gitignore .vs/ .vscode/ .DS_Store From 095231dfd32679e32300f8ffaf1770b693ea64b0 Mon Sep 17 00:00:00 2001 From: bandoti <141645996+bandoti@users.noreply.github.com> Date: Mon, 2 Oct 2023 06:51:49 -0300 Subject: [PATCH 09/12] cmake : fix transient definitions in find pkg (#3411) --- CMakeLists.txt | 1 + examples/main-cmake-pkg/CMakeLists.txt | 10 ++++++++++ scripts/LlamaConfig.cmake.in | 2 ++ 3 files changed, 13 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index d5acf8540..17d705422 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -705,6 +705,7 @@ set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER}) set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT}) set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER}) +get_directory_property(LLAMA_TRANSIENT_DEFINES COMPILE_DEFINITIONS) configure_package_config_file( ${CMAKE_CURRENT_SOURCE_DIR}/scripts/LlamaConfig.cmake.in diff --git a/examples/main-cmake-pkg/CMakeLists.txt b/examples/main-cmake-pkg/CMakeLists.txt index 473738719..908131884 100644 --- a/examples/main-cmake-pkg/CMakeLists.txt +++ b/examples/main-cmake-pkg/CMakeLists.txt @@ -28,6 +28,16 @@ configure_file(${_common_path}/../build-info.h target_include_directories(common PUBLIC ${LLAMA_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) +# If the common project was part of "main-cmake-pkg" the transient +# defines would automatically be attached. Because the common func- +# tionality is separate, but dependent upon the defines, it must be +# explicitly extracted from the "llama" target. +# +get_target_property(_llama_transient_defines llama + INTERFACE_COMPILE_DEFINITIONS) + +target_compile_definitions(common PRIVATE "${_llama_transient_defines}") + add_executable(${TARGET} ${CMAKE_CURRENT_LIST_DIR}/../main/main.cpp) target_include_directories(${TARGET} PRIVATE ${_common_path}) install(TARGETS ${TARGET} RUNTIME) diff --git a/scripts/LlamaConfig.cmake.in b/scripts/LlamaConfig.cmake.in index e1fadc361..6a6d8e39e 100644 --- a/scripts/LlamaConfig.cmake.in +++ b/scripts/LlamaConfig.cmake.in @@ -56,11 +56,13 @@ find_library(llama_LIBRARY llama HINTS ${LLAMA_LIB_DIR}) set(_llama_link_deps "Threads::Threads" "@LLAMA_EXTRA_LIBS@") +set(_llama_transient_defines "@LLAMA_TRANSIENT_DEFINES@") add_library(llama UNKNOWN IMPORTED) set_target_properties(llama PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${LLAMA_INCLUDE_DIR}" INTERFACE_LINK_LIBRARIES "${_llama_link_deps}" + INTERFACE_COMPILE_DEFINITIONS "${_llama_transient_defines}" IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" IMPORTED_LOCATION "${llama_LIBRARY}" INTERFACE_COMPILE_FEATURES cxx_std_11 From a84767698495d72e44044f1f6db1c1cc721bfd15 Mon Sep 17 00:00:00 2001 From: Adrian Date: Mon, 2 Oct 2023 03:49:59 -0700 Subject: [PATCH 10/12] metal : set log callback before initializing (#3427) --- llama.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 3a0b2c308..05b570bd1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6760,13 +6760,14 @@ struct llama_context * llama_new_context_with_model( #ifdef GGML_USE_METAL if (model->n_gpu_layers > 0) { + ggml_metal_log_set_callback(llama_log_callback_default, NULL); + ctx->ctx_metal = ggml_metal_init(1); if (!ctx->ctx_metal) { LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); llama_free(ctx); return NULL; } - ggml_metal_log_set_callback(llama_log_callback_default, NULL); //ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); } From a03ce38455544121c5c00cf845def1443acd6ac8 Mon Sep 17 00:00:00 2001 From: xaedes Date: Mon, 2 Oct 2023 15:15:45 +0200 Subject: [PATCH 11/12] finetune : fix #3404 (#3437) the shapes for init model of gqa models was wrong --- examples/finetune/finetune.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 8ca1874da..9ae4bc198 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -332,8 +332,8 @@ static void init_model(struct llama_model * input, struct my_llama_model * model assert_shape_1d(layer.attention_norm, hparams.n_embd); assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd); - assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd); - assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd); + assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd_gqa()); + assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd_gqa()); assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd); assert_shape_1d(layer.ffn_norm, hparams.n_embd); assert_shape_2d(layer.w1, hparams.n_embd, hparams.n_ff); From 9476b012260a2fb6c67976582d64484ce7406ed9 Mon Sep 17 00:00:00 2001 From: cebtenzzre Date: Mon, 2 Oct 2023 09:16:50 -0400 Subject: [PATCH 12/12] cmake : make CUDA flags more similar to the Makefile (#3420) * cmake : fix misuse of cxx_flags * cmake : make CUDA flags more similar to the Makefile * cmake : fix MSVC build --- CMakeLists.txt | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 17d705422..221fc07a2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -343,8 +343,9 @@ if (LLAMA_MPI) set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h) add_compile_definitions(GGML_USE_MPI) add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) - set(cxx_flags ${cxx_flags} -Wno-cast-qual) - set(c_flags ${c_flags} -Wno-cast-qual) + if (NOT MSVC) + add_compile_options(-Wno-cast-qual) + endif() set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES}) set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS}) # Even if you're only using the C header, C++ programs may bring in MPI @@ -418,10 +419,11 @@ if (LLAMA_ALL_WARNINGS) set(c_flags -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int -Werror=implicit-function-declaration) set(cxx_flags -Wmissing-declarations -Wmissing-noreturn) + set(host_cxx_flags "") if (CMAKE_C_COMPILER_ID MATCHES "Clang") set(warning_flags ${warning_flags} -Wunreachable-code-break -Wunreachable-code-return) - set(cxx_flags ${cxx_flags} -Wmissing-prototypes -Wextra-semi) + set(host_cxx_flags ${host_cxx_flags} -Wmissing-prototypes -Wextra-semi) if ( (CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 3.8.0) OR @@ -431,27 +433,38 @@ if (LLAMA_ALL_WARNINGS) endif() elseif (CMAKE_C_COMPILER_ID STREQUAL "GNU") set(c_flags ${c_flags} -Wdouble-promotion) - set(cxx_flags ${cxx_flags} -Wno-array-bounds) + set(host_cxx_flags ${host_cxx_flags} -Wno-array-bounds) if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 7.1.0) - set(cxx_flags ${cxx_flags} -Wno-format-truncation) + set(host_cxx_flags ${host_cxx_flags} -Wno-format-truncation) endif() if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1.0) - set(cxx_flags ${cxx_flags} -Wextra-semi) + set(host_cxx_flags ${host_cxx_flags} -Wextra-semi) endif() endif() else() # todo : msvc endif() - add_compile_options( - ${warning_flags} - "$<$:${c_flags}>" - "$<$:${cxx_flags}>" - ) + set(c_flags ${c_flags} ${warning_flags}) + set(cxx_flags ${cxx_flags} ${warning_flags}) + add_compile_options("$<$:${c_flags}>" + "$<$:${cxx_flags} ${host_cxx_flags}>") endif() +if (NOT MSVC) + set(cuda_flags -Wno-pedantic) +endif() +set(cuda_flags ${cxx_flags} -use_fast_math ${cuda_flags}) + +list(JOIN host_cxx_flags " " cuda_host_flags) # pass host compiler flags as a single argument +if (NOT cuda_host_flags STREQUAL "") + set(cuda_flags ${cuda_flags} -Xcompiler ${cuda_host_flags}) +endif() + +add_compile_options("$<$:${cuda_flags}>") + if (WIN32) add_compile_definitions(_CRT_SECURE_NO_WARNINGS)