Merge branch 'ggerganov:master' into bitnet

This commit is contained in:
Eddie-Wang 2024-06-11 10:50:12 +08:00 committed by GitHub
commit 2322e9db9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 699 additions and 263 deletions

View File

@ -16,11 +16,9 @@ on:
branches:
- master
paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*']
pull_request_target:
pull_request:
types: [opened, synchronize, reopened]
paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*']
schedule:
- cron: '2 4 * * *'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }}
@ -115,7 +113,7 @@ jobs:
server-windows:
runs-on: windows-latest
runs-on: windows-2019
steps:
- name: Clone

View File

@ -402,12 +402,26 @@ if (LLAMA_CUBLAS)
endif()
if (LLAMA_CUDA)
cmake_minimum_required(VERSION 3.17)
cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
find_package(CUDAToolkit)
if (CUDAToolkit_FOUND)
message(STATUS "CUDA found")
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# 52 == lowest CUDA 12 standard
# 60 == f16 CUDA intrinsics
# 61 == integer CUDA intrinsics
# 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
else()
set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
#set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work
endif()
endif()
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
enable_language(CUDA)
set(GGML_HEADERS_CUDA ggml-cuda.h)
@ -472,21 +486,6 @@ if (LLAMA_CUDA)
else()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
endif()
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# 52 == lowest CUDA 12 standard
# 60 == f16 CUDA intrinsics
# 61 == integer CUDA intrinsics
# 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
else()
set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
#set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work
endif()
endif()
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
else()
message(WARNING "CUDA not found")
endif()

View File

@ -53,7 +53,6 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
<li><a href="#quantization">Quantization</a></li>
<li><a href="#interactive-mode">Interactive mode</a></li>
<li><a href="#constrained-output-with-grammars">Constrained output with grammars</a></li>
<li><a href="#instruct-mode">Instruct mode</a></li>
<li><a href="#obtaining-and-using-the-facebook-llama-2-model">Obtaining and using the Facebook LLaMA 2 model</a></li>
<li><a href="#seminal-papers-and-background-on-the-models">Seminal papers and background on the models</a></li>
<li><a href="#perplexity-measuring-model-quality">Perplexity (measuring model quality)</a></li>
@ -769,34 +768,6 @@ The `grammars/` folder contains a handful of sample grammars. To write your own,
For authoring more complex JSON grammars, you can also check out https://grammar.intrinsiclabs.ai/, a browser app that lets you write TypeScript interfaces which it compiles to GBNF grammars that you can save for local use. Note that the app is built and maintained by members of the community, please file any issues or FRs on [its repo](http://github.com/intrinsiclabsai/gbnfgen) and not this one.
### Instruct mode
1. First, download and place the `ggml` model into the `./models` folder
2. Run the `main` tool like this:
```
./examples/alpaca.sh
```
Sample run:
```
== Running in interactive mode. ==
- Press Ctrl+C to interject at any time.
- Press Return to return control to LLaMA.
- If you want to submit another line, end your input in '\'.
Below is an instruction that describes a task. Write a response that appropriately completes the request.
> How many letters are there in the English alphabet?
There 26 letters in the English Alphabet
> What is the most common way of transportation in Amsterdam?
The majority (54%) are using public transit. This includes buses, trams and metros with over 100 lines throughout the city which make it very accessible for tourists to navigate around town as well as locals who commute by tram or metro on a daily basis
> List 5 words that start with "ca".
cadaver, cauliflower, cabbage (vegetable), catalpa (tree) and Cailleach.
>
```
### Obtaining and using the Facebook LLaMA 2 model
- Refer to [Facebook's LLaMA download page](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) if you want to access the model data.

View File

@ -40,7 +40,7 @@ static std::string build_repetition(const std::string & item_rule, int min_items
return result;
}
const std::string SPACE_RULE = "\" \"?";
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
struct BuiltinRule {
std::string content;
@ -57,7 +57,7 @@ std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
{"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
{"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
{"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
{"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
{"null", {"\"null\" space", {}}},
};

View File

@ -1,19 +0,0 @@
#!/bin/bash
#
# Temporary script - will be removed in the future
#
cd `dirname $0`
cd ..
./main -m ./models/alpaca.13b.ggmlv3.q8_0.bin \
--color \
-f ./prompts/alpaca.txt \
--ctx_size 2048 \
-n -1 \
-ins -b 256 \
--top_k 10000 \
--temp 0.2 \
--repeat_penalty 1.1 \
-t 7

View File

@ -1,15 +0,0 @@
#!/bin/bash
#
# Temporary script - will be removed in the future
#
cd `dirname $0`
cd ..
./main --color --instruct --threads 4 \
--model ./models/gpt4all-7B/gpt4all-lora-quantized.bin \
--file ./prompts/alpaca.txt \
--batch_size 8 --ctx_size 2048 -n -1 \
--repeat_last_n 64 --repeat_penalty 1.3 \
--n_predict 128 --temp 0.1 --top_k 40 --top_p 0.95

View File

@ -29,9 +29,8 @@ class BuiltinRule:
self.content = content
self.deps = deps or []
# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'
# Constraining spaces to prevent model "running away".
SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}'
PRIMITIVE_RULES = {
'boolean' : BuiltinRule('("true" | "false") space', []),
@ -43,7 +42,7 @@ PRIMITIVE_RULES = {
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})', []),
'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
'null' : BuiltinRule('"null" space', []),
}

View File

@ -1,18 +0,0 @@
#!/bin/bash
#
# Temporary script - will be removed in the future
#
cd `dirname $0`
cd ..
./main -m models/available/Llama2/13B/llama-2-13b.ggmlv3.q4_0.bin \
--color \
--ctx_size 2048 \
-n -1 \
-ins -b 256 \
--top_k 10000 \
--temp 0.2 \
--repeat_penalty 1.1 \
-t 8

View File

@ -1,18 +0,0 @@
#!/bin/bash
#
# Temporary script - will be removed in the future
#
cd `dirname $0`
cd ..
./main -m models/available/Llama2/7B/llama-2-7b.ggmlv3.q4_0.bin \
--color \
--ctx_size 2048 \
-n -1 \
-ins -b 256 \
--top_k 10000 \
--temp 0.2 \
--repeat_penalty 1.1 \
-t 8

View File

@ -1,5 +1,5 @@
// WARNING: This file was ported from json_schema_to_grammar.py, please fix bugs / add features there first.
const SPACE_RULE = '" "?';
const SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}';
function _buildRepetition(itemRule, minItems, maxItems, opts={}) {
if (minItems === 0 && maxItems === 1) {
@ -41,7 +41,7 @@ const PRIMITIVE_RULES = {
object : new BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
array : new BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
uuid : new BuiltinRule('"\\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\\"" space', []),
char : new BuiltinRule(`[^"\\\\] | "\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F]{4})`, []),
char : new BuiltinRule(`[^"\\\\\\x7F\\x00-\\x1F] | [\\\\] (["\\\\bfnrt] | "u" [0-9a-fA-F]{4})`, []),
string : new BuiltinRule(`"\\"" char* "\\"" space`, ['char']),
null : new BuiltinRule('"null" space', []),
};

View File

@ -147,7 +147,7 @@ struct server_slot {
int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_processed = 0;
json prompt;
std::string prompt;
// when a task is submitted, we first tokenize the prompt and store it here
std::vector<llama_token> prompt_tokens;
@ -822,13 +822,8 @@ struct server_context {
continue;
}
// skip the slot if it does not contains prompt
if (!slot.prompt.is_string()) {
continue;
}
// current slot's prompt
std::string slot_prompt = slot.prompt.get<std::string>();
std::string slot_prompt = slot.prompt;
// length of the current slot's prompt
int slot_prompt_len = slot_prompt.size();
@ -958,13 +953,16 @@ struct server_context {
if (!task.infill) {
const auto & prompt = data.find("prompt");
if (prompt == data.end()) {
send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
return false;
} else {
slot.prompt = *prompt;
}
if (slot.prompt.is_array() && slot.prompt.size() == 0) {
send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST);
if (prompt->is_string()) {
slot.prompt = prompt->get<std::string>();
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) {
slot.prompt = prompt->at(0).get<std::string>();
} else {
send_error(task, "\"prompt\" must be a string or an array of strings", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
@ -1582,14 +1580,18 @@ struct server_context {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
{
int id_slot = json_value(task.data, "id_slot", -1);
std::string prompt = json_value(task.data, "prompt", std::string());
const int id_slot = json_value(task.data, "id_slot", -1);
server_slot * slot;
if (id_slot != -1) {
slot = get_slot_by_id(id_slot);
} else {
std::string prompt;
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
json_value(task.data, "prompt", std::string());
}
slot = get_available_slot(prompt);
}

View File

@ -139,6 +139,7 @@
#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
#define CC_TURING 750
#define CC_AMPERE 800
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
#endif // defined(GGML_USE_HIPBLAS)
#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#define FP16_AVAILABLE
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
#define INT8_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
static bool fast_fp16_available(const int cc) {
return cc >= CC_PASCAL && cc != 610;
@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) {
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
}
static bool int8_mma_available(const int cc) {
return cc < CC_OFFSET_AMD && cc >= CC_TURING;
}
[[noreturn]]
static __device__ void no_device_code(
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
}
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#pragma unroll
@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
}
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
return __float2half(fmaxf(__half2float(a), __half2float(b)));

View File

@ -74,7 +74,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
const int sumi = __dp4a(v, u, 0);
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
@ -122,7 +122,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
const int sumi = __dp4a(v, u, 0);
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
@ -181,7 +181,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
const int sumi = __dp4a(v, u, 0);
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
const int sumi = __dp4a(v, u, 0);
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
@ -314,7 +314,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_h2 = (const half2 *) Q_v;
@ -407,7 +407,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
const int q0 = x[ib].qs[iqs];
const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}
@ -428,7 +428,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__
const int q0 = x[ib].qs[iqs];
const int q = ((q0 >> (4*shift)) & 0x0F);
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return __low2half(dm)*((half) q) + __high2half(dm);
}
@ -453,7 +453,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh) - 16;
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}
@ -478,7 +478,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh);
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return __low2half(dm)*((half) q) + __high2half(dm);
}
@ -497,7 +497,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
const T d = x[ib].d;
const int q = x[ib].qs[iqs];
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}

View File

@ -43,7 +43,7 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.

View File

@ -40,7 +40,7 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);

View File

@ -1,9 +1,9 @@
#include "common.cuh"
#include "fattn-common.cuh"
#if FP16_MMA_AVAILABLE
#ifdef FP16_MMA_AVAILABLE
#include <mma.h>
#endif
#endif // FP16_MMA_AVAILABLE
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if FP16_MMA_AVAILABLE
#ifdef FP16_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.

95
ggml-cuda/mma.cuh Normal file
View File

@ -0,0 +1,95 @@
#include "common.cuh"
struct mma_int_A_I16K8 {
static constexpr int I = 16;
static constexpr int K = 8;
static constexpr int ne = 4;
int x[ne] = {0};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_k(const int l) {
const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
};
struct mma_int_B_J8K8 {
static constexpr int J = 8;
static constexpr int K = 8;
static constexpr int ne = 2;
int x[ne] = {0};
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
static __device__ __forceinline__ int get_k(const int l) {
const int ret = l * (K/2) + threadIdx.x % (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
};
struct mma_int_C_I16J8 {
static constexpr int I = 16;
static constexpr int J = 8;
static constexpr int ne = 4;
int x[ne] = {0};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int l) {
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
#else
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[2]), "r"(mma_B.x[1]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[3]), "r"(mma_B.x[1]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
};

View File

@ -2,6 +2,7 @@
#include "common.cuh"
#include "vecdotq.cuh"
#include "mma.cuh"
#include <climits>
#include <cstdint>
@ -14,6 +15,7 @@ typedef void (*load_tiles_mmq_t)(
typedef void (*vec_dot_mmq_t)(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0);
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
struct block_q8_1_mmq {
half2 ds[4];
@ -141,15 +143,15 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
const float * x_dmf = (const float *) x_dm;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const float * x_df = (const float *) x_dm;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@ -170,12 +172,76 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
}
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
(&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
(&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
}
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
const float * x_df = (const float *) x_dm;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
mma_A A;
float dA[mma_C::ne/2];
const int i0 = threadIdx.y*mma_A::I;
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
#pragma unroll
for (int l = 0; l < mma_A::ne; ++l) {
const int i = i0 + mma_A::get_i(l);
const int k = k0 + mma_A::get_k(l) % QI4_0;
const int shift = 4*(mma_A::get_k(l) / QI4_0);
A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + mma_C::get_i(2*l);
dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
}
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
mma_C C;
mma_B B;
half2 dsB[mma_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int j = j0 + mma_B::get_j(l);
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
}
C.mma_K8(A, B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
}
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@ -215,7 +281,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
@ -249,6 +315,70 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
mma_A A;
half2 dmA[mma_C::ne/2];
const int i0 = threadIdx.y*mma_A::I;
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
#pragma unroll
for (int l = 0; l < mma_A::ne; ++l) {
const int i = i0 + mma_A::get_i(l);
const int k = k0 + mma_A::get_k(l) % QI4_0;
const int shift = 4*(mma_A::get_k(l) / QI4_0);
A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + mma_C::get_i(2*l);
dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
}
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
mma_C C;
mma_B B;
half2 dsB[mma_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int j = j0 + mma_B::get_j(l);
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
}
C.mma_K8(A, B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
}
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@ -308,7 +438,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
@ -343,6 +473,68 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
const float * x_df = (const float *) x_dm;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
mma_A A;
float dA[mma_C::ne/2];
const int i0 = threadIdx.y*mma_A::I;
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
#pragma unroll
for (int l = 0; l < mma_A::ne; ++l) {
const int i = i0 + mma_A::get_i(l);
const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + mma_C::get_i(2*l);
dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
}
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
mma_C C;
mma_B B;
float dB[mma_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int j = j0 + mma_B::get_j(l);
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
}
C.mma_K8(A, B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
}
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
@ -400,7 +592,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
@ -434,6 +626,69 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
mma_A A;
half2 dmA[mma_C::ne/2];
const int i0 = threadIdx.y*mma_A::I;
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
#pragma unroll
for (int l = 0; l < mma_A::ne; ++l) {
const int i = i0 + mma_A::get_i(l);
const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + mma_C::get_i(2*l);
dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
}
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
mma_C C;
mma_B B;
half2 dsB[mma_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int j = j0 + mma_B::get_j(l);
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
}
C.mma_K8(A, B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
}
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@ -475,7 +730,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
@ -500,6 +755,69 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
const float * x_df = (const float *) x_dm;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
mma_A A;
float dA[mma_C::ne/2];
const int i0 = threadIdx.y*mma_A::I;
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
#pragma unroll
for (int l = 0; l < mma_A::ne; ++l) {
const int i = i0 + mma_A::get_i(l);
const int k = k0 + mma_A::get_k(l);
A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + mma_C::get_i(2*l);
dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
}
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
mma_C C;
mma_B B;
float dB[mma_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int j = j0 + mma_B::get_j(l);
const int k = k0 + mma_B::get_k(l);
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
}
C.mma_K8(A, B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
}
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@ -989,6 +1307,57 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
}
}
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
if (j >= ne1) {
return;
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
if (need_check && i >= ne0) {
continue;
}
dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
}
}
}
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
typedef mma_int_C_I16J8 mma_C;
const int i0 = threadIdx.y*mma_C::I;
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
if (j >= ne1) {
continue;
}
const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
if (need_check && i >= ne0) {
continue;
}
dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
}
}
}
// -------------------------------------------------------------------------------------------------------------------------------------
template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
@ -998,35 +1367,65 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
#ifdef INT8_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
#ifdef INT8_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
#ifdef INT8_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
#ifdef INT8_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
#ifdef INT8_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@ -1034,6 +1433,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@ -1041,6 +1441,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@ -1048,6 +1449,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@ -1055,6 +1457,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@ -1062,6 +1465,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};
static int mmq_need_sum(const ggml_type type_x) {
@ -1118,6 +1522,7 @@ static __global__ void mul_mat_q(
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back;
constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
@ -1137,7 +1542,7 @@ static __global__ void mul_mat_q(
const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
@ -1164,25 +1569,7 @@ static __global__ void mul_mat_q(
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
if (j >= ne1) {
return;
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
if (need_check && i >= ne0) {
continue;
}
dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
}
}
write_back(sum, dst, ne0, ne1);
}
struct mmq_args {
@ -1256,10 +1643,10 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
launch_mul_mat_q<type, 8, 4>(args, stream);
break;
case 16:
launch_mul_mat_q<type, 16, 8>(args, stream);
launch_mul_mat_q<type, 16, 4>(args, stream);
break;
case 24:
launch_mul_mat_q<type, 24, 8>(args, stream);
launch_mul_mat_q<type, 24, 4>(args, stream);
break;
case 32:
launch_mul_mat_q<type, 32, 8>(args, stream);

View File

@ -13089,10 +13089,12 @@ void *ggml_sycl_host_malloc(size_t size) try {
return nullptr;
}
ggml_sycl_set_device(g_main_device);
dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
void * ptr = nullptr;
//allow to use dpct::get_in_order_queue() for host malloc
dpct::err0 err = CHECK_TRY_ERROR(
ptr = (void *)sycl::malloc_host(size, dpct::get_in_order_queue()));
ptr = (void *)sycl::malloc_host(size, *main_stream));
if (err != 0) {
// clear the error
@ -13113,8 +13115,9 @@ catch (sycl::exception const &exc) {
}
void ggml_sycl_host_free(void *ptr) try {
//allow to use dpct::get_in_order_queue() for host malloc
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue())));
ggml_sycl_set_device(g_main_device);
dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *main_stream)));
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__

View File

@ -94,6 +94,8 @@ This guide provides a brief overview. Check out the GBNF files in this directory
./main -m <model> --grammar-file grammars/some-grammar.gbnf -p 'Some prompt'
```
`llama.cpp` can also convert JSON schemas to grammars either ahead of time or at each request, see below.
## Troubleshooting
Grammars currently have performance gotchas (see https://github.com/ggerganov/llama.cpp/issues/4218).
@ -103,3 +105,40 @@ Grammars currently have performance gotchas (see https://github.com/ggerganov/ll
A common pattern is to allow repetitions of a pattern `x` up to N times.
While semantically correct, the syntax `x? x? x?.... x?` (with N repetitions) may result in extremely slow sampling. Instead, you can write `x{0,N}` (or `(x (x (x ... (x)?...)?)?)?` w/ N-deep nesting in earlier llama.cpp versions).
## Using GBNF grammars
You can use GBNF grammars:
- In the [server](../examples/server)'s completion endpoints, passed as the `grammar` body field
- In the [main](../examples/main) CLI, passed as the `--grammar` & `--grammar-file` flags
- With the [gbnf-validator](../examples/gbnf-validator) tool, to test them against strings.
## JSON Schemas → GBNF
`llama.cpp` supports converting a subset of https://json-schema.org/ to GBNF grammars:
- In the [server](../examples/server):
- For any completion endpoints, passed as the `json_schema` body field
- For the `/chat/completions` endpoint, passed inside the `result_format` body field (e.g. `{"type", "json_object", "schema": {"items": {}}}`)
- In the [main](../examples/main) CLI, passed as the `--json` / `-j` flag
- To convert to a grammar ahead of time:
- in CLI, with [json_schema_to_grammar.py](../examples/json_schema_to_grammar.py)
- in JavaScript with [json-schema-to-grammar.mjs](../examples/server/public/json-schema-to-grammar.mjs) (this is used by the [server](../examples/server)'s Web UI)
Take a look at [tests](../../tests/test-json-schema-to-grammar.cpp) to see which features are likely supported (you'll also find usage examples in https://github.com/ggerganov/llama.cpp/pull/5978, https://github.com/ggerganov/llama.cpp/pull/6659 & https://github.com/ggerganov/llama.cpp/pull/6555).
Here is also a non-exhaustive list of **unsupported** features:
- `additionalProperties`: to be fixed in https://github.com/ggerganov/llama.cpp/pull/7840
- `minimum`, `exclusiveMinimum`, `maximum`, `exclusiveMaximum`
- `integer` constraints to be implemented in https://github.com/ggerganov/llama.cpp/pull/7797
- Remote `$ref`s in the C++ version (Python & JavaScript versions fetch https refs)
- Mixing `properties` w/ `anyOf` / `oneOf` in the same type (https://github.com/ggerganov/llama.cpp/issues/7703)
- `string` formats `uri`, `email`
- [`contains`](https://json-schema.org/draft/2020-12/json-schema-core#name-contains) / `minContains`
- `uniqueItems`
- `$anchor` (cf. [dereferencing](https://json-schema.org/draft/2020-12/json-schema-core#name-dereferencing))
- [`not`](https://json-schema.org/draft/2020-12/json-schema-core#name-not)
- [Conditionals](https://json-schema.org/draft/2020-12/json-schema-core#name-keywords-for-applying-subsche) `if` / `then` / `else` / `dependentSchemas`
- [`patternProperties`](https://json-schema.org/draft/2020-12/json-schema-core#name-patternproperties)

View File

@ -16,10 +16,10 @@ array ::=
string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
ws ::= | " " | "\n" [ \t]{0,20}

View File

@ -25,10 +25,10 @@ array ::=
string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
ws ::= | " " | "\n" [ \t]{0,20}

View File

@ -105,14 +105,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null
)"""
@ -135,7 +135,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
date-time ::= date "T" time
date-time-string ::= "\"" date-time "\"" space
root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
time-string ::= "\"" time "\"" space
tuple-0 ::= date-string
@ -152,9 +152,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"type": "string"
})""",
R"""(
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char* "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -166,9 +166,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"minLength": 1
})""",
R"""(
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char+ "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -180,9 +180,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"minLength": 3
})""",
R"""(
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char{3,} "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -194,9 +194,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maxLength": 3
})""",
R"""(
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char{0,3} "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -209,9 +209,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maxLength": 4
})""",
R"""(
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char{1,4} "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -223,7 +223,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= ("true" | "false") space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -236,7 +236,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
integral-part ::= [0] | [1-9] [0-9]{0,15}
root ::= ("-"? integral-part) space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -248,7 +248,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "\"foo\""
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -260,7 +260,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "123"
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -272,7 +272,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]"
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -283,9 +283,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"prefixItems": [{ "type": "string" }]
})""",
R"""(
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "[" space string "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
@ -297,12 +297,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"prefixItems": [{ "type": "string" }, { "type": "number" }]
})""",
R"""(
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "[" space string "," space number "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
@ -317,7 +317,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -333,7 +333,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space boolean ("," space boolean)+ "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -349,7 +349,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space boolean? "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -365,7 +365,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space (boolean ("," space boolean)?)? "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -386,7 +386,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
item ::= number | integer
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "[" space item ("," space item){2,4} "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -399,7 +399,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "\"" "ab" "c"? "d"* "ef" "g"+ ("hij")? "kl" "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -412,7 +412,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "\"" "[]{}()|+*?" "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -425,7 +425,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "\"" "\"" "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -440,7 +440,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
dot ::= [^\x0A\x0D]
root ::= "\"" ("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot "\"" space
root-1 ::= [0-9]
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -466,9 +466,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
a-kv ::= "\"a\"" space ":" space string
b-kv ::= "\"b\"" space ":" space string
c-kv ::= "\"c\"" space ":" space string
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
@ -486,9 +486,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "{" space (a-kv )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
@ -510,9 +510,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
b-kv ::= "\"b\"" space ":" space string
b-rest ::= ( "," space c-kv )?
c-kv ::= "\"c\"" space ":" space string
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
@ -534,11 +534,11 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
a-kv ::= "\"a\"" space ":" space string
b-kv ::= "\"b\"" space ":" space string
c-kv ::= "\"c\"" space ":" space string
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
d-kv ::= "\"d\"" space ":" space string
d-rest ::= ( "," space c-kv )?
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
@ -554,12 +554,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
additional-kv ::= string ":" space additional-value
additional-kvs ::= additional-kv ( "," space additional-kv )*
additional-value ::= "[" space (number ("," space number)*)? "]" space
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space (additional-kvs )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
@ -574,14 +574,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null
)"""
@ -596,14 +596,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null
)"""
@ -618,7 +618,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "{" space "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -637,12 +637,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
a-kv ::= "\"a\"" space ":" space number
additional-kv ::= string ":" space string
additional-kvs ::= additional-kv ( "," space additional-kv )*
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space a-kv ( "," space ( additional-kvs ) )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
@ -662,12 +662,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
a-rest ::= additional-kvs
additional-kv ::= string ":" space number
additional-kvs ::= additional-kv ( "," space additional-kv )*
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space (a-kv a-rest | additional-kvs )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
@ -690,12 +690,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
additional-kvs ::= additional-kv ( "," space additional-kv )*
b-kv ::= "\"b\"" space ":" space number
b-rest ::= additional-kvs
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space a-kv ( "," space ( b-kv b-rest | additional-kvs ) )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
@ -721,11 +721,11 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
}
})""",
R"""(
char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
foo ::= "{" space foo-a-kv "}" space
foo-a-kv ::= "\"a\"" space ":" space string
root ::= foo
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
@ -759,7 +759,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= alternative-0 | alternative-1
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -803,7 +803,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@ -851,7 +851,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
number-number-kv ::= "\"number\"" space ":" space number-number
number-number-root-kv ::= "\"root\"" space ":" space number
root ::= "{" space number-kv "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
}