From d4d915d351d1f1270d56184bdd46672893e8a5d8 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sat, 8 Jun 2024 20:21:08 +0100 Subject: [PATCH 01/15] url: save -mu downloads to new cache location (#7826) * url: save -mu download to new cache location * url: fs_get_cache_file_path util * url: tweak sig of fs_get_cache_file --- common/common.cpp | 20 ++++++++++++-------- common/common.h | 1 + 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d2a8bb69e..1591790e6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -200,19 +200,13 @@ void gpt_params_handle_model_default(gpt_params & params) { } params.hf_file = params.model; } else if (params.model.empty()) { - std::string cache_directory = fs_get_cache_directory(); - const bool success = fs_create_directory_with_parents(cache_directory); - if (!success) { - throw std::runtime_error("failed to create cache directory: " + cache_directory); - } - params.model = cache_directory + string_split(params.hf_file, '/').back(); + params.model = fs_get_cache_file(string_split(params.hf_file, '/').back()); } } else if (!params.model_url.empty()) { if (params.model.empty()) { auto f = string_split(params.model_url, '#').front(); f = string_split(f, '?').front(); - f = string_split(f, '/').back(); - params.model = "models/" + f; + params.model = fs_get_cache_file(string_split(f, '/').back()); } } else if (params.model.empty()) { params.model = DEFAULT_MODEL_PATH; @@ -2279,6 +2273,16 @@ std::string fs_get_cache_directory() { return ensure_trailing_slash(cache_directory); } +std::string fs_get_cache_file(const std::string & filename) { + GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos); + std::string cache_directory = fs_get_cache_directory(); + const bool success = fs_create_directory_with_parents(cache_directory); + if (!success) { + throw std::runtime_error("failed to create cache directory: " + cache_directory); + } + return cache_directory + filename; +} + // // Model utils diff --git a/common/common.h b/common/common.h index 038f9084f..2345d855e 100644 --- a/common/common.h +++ b/common/common.h @@ -277,6 +277,7 @@ bool fs_validate_filename(const std::string & filename); bool fs_create_directory_with_parents(const std::string & path); std::string fs_get_cache_directory(); +std::string fs_get_cache_file(const std::string & filename); // // Model utils From fe1e3917cfa0f9397a765cfd0aef880674d938d5 Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 9 Jun 2024 01:43:39 +0200 Subject: [PATCH 02/15] Revert "[SYCL] Update rpc-server.cpp to include SYCL backend (#7682)" (#7808) This reverts commit 9422c5e34bbd302493b77a8f6d546154a1f4fe82. 
--- examples/rpc/rpc-server.cpp | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp index 62d828250..7c15d2aa4 100644 --- a/examples/rpc/rpc-server.cpp +++ b/examples/rpc/rpc-server.cpp @@ -6,10 +6,6 @@ #include "ggml-metal.h" #endif -#ifdef GGML_USE_SYCL -#include "ggml-sycl.h" -#endif - #include "ggml-rpc.h" #ifdef _WIN32 # include @@ -83,12 +79,6 @@ static ggml_backend_t create_backend() { if (!backend) { fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); } -#elif GGML_USE_SYCL - fprintf(stderr, "%s: using SYCL backend\n", __func__); - backend = ggml_backend_sycl_init(0); // init device 0 - if (!backend) { - fprintf(stderr, "%s: ggml_backend_sycl_init() failed\n", __func__); - } #endif // if there aren't GPU Backends fallback to CPU backend From ed9f2521185706481501a5e6d5315397b11802ff Mon Sep 17 00:00:00 2001 From: compilade Date: Sat, 8 Jun 2024 22:34:29 -0400 Subject: [PATCH 03/15] gguf-py : decouple adding metadata from writing in GGUFWriter (#7827) Main changes of this PR is to consolidate GGUFWriter.add_key and GGUFWriter.add_val into GGUFWriter.add_key_value. In addition use_temp_file is now opt-in instead of opt-out defaulting to False. Also GGUFWriter now does not require output file name until when actually writing to it. And GGUFWriter doesn't really need to eagerly prepare the data layout of the metadata --- convert-hf-to-gguf.py | 8 +- gguf-py/gguf/gguf_writer.py | 270 +++++++++++++++------------ gguf-py/scripts/gguf-new-metadata.py | 6 +- 3 files changed, 160 insertions(+), 124 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a86864f04..0327712d7 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -47,7 +47,7 @@ class Model: _model_classes: dict[str, type[Model]] = {} dir_model: Path - ftype: int + ftype: gguf.LlamaFileType is_big_endian: bool endianess: gguf.GGUFEndian use_temp_file: bool @@ -94,7 +94,7 @@ class Model: ftype_lw: str = ftype_up.lower() # allow templating the file name with the output ftype, useful with the "auto" ftype self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up) - self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) + self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) @classmethod def __init_subclass__(cls): @@ -324,13 +324,13 @@ class Model: def write(self): self.write_tensors() - self.gguf_writer.write_header_to_file() + self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_kv_data_to_file() self.gguf_writer.write_tensors_to_file(progress=True) self.gguf_writer.close() def write_vocab(self): - self.gguf_writer.write_header_to_file() + self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_kv_data_to_file() self.gguf_writer.close() diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index b93747aff..ed56abfb3 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -5,6 +5,7 @@ import os import shutil import struct import tempfile +from dataclasses import dataclass from enum import Enum, auto from io import BufferedWriter from typing import IO, Any, Sequence, Mapping @@ -30,17 +31,36 @@ from .quants import quant_shape_from_byte_shape logger = 
logging.getLogger(__name__) +@dataclass +class TensorInfo: + shape: Sequence[int] + dtype: GGMLQuantizationType + nbytes: int + tensor: np.ndarray[Any, Any] | None = None + + +@dataclass +class GGUFValue: + value: Any + type: GGUFValueType + + class WriterState(Enum): + NO_FILE = auto() EMPTY = auto() HEADER = auto() KV_DATA = auto() TI_DATA = auto() + WEIGHTS = auto() class GGUFWriter: - fout: BufferedWriter + fout: BufferedWriter | None + path: os.PathLike[str] | str | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None - tensors: list[np.ndarray[Any, Any]] + tensors: dict[str, TensorInfo] + kv_data: dict[str, GGUFValue] + state: WriterState _simple_value_packing = { GGUFValueType.UINT8: "B", GGUFValueType.INT8: "b", @@ -56,141 +76,140 @@ class GGUFWriter: } def __init__( - self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True, + self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE, ): - self.fout = open(path, "wb") + self.fout = None + self.path = path self.arch = arch self.endianess = endianess - self.offset_tensor = 0 self.data_alignment = GGUF_DEFAULT_ALIGNMENT - self.kv_data = bytearray() - self.kv_data_count = 0 - self.ti_data = bytearray() - self.ti_data_count = 0 - self.ti_names = set() self.use_temp_file = use_temp_file self.temp_file = None - self.tensors = [] + self.tensors = dict() + self.kv_data = dict() logger.info("gguf: This GGUF file is for {0} Endian only".format( "Big" if self.endianess == GGUFEndian.BIG else "Little", )) - self.state = WriterState.EMPTY + self.state = WriterState.NO_FILE self.add_architecture() - def write_header_to_file(self) -> None: + def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None: + if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path): + # allow calling this multiple times as long as the path is the same + return + if self.state is not WriterState.NO_FILE: + raise ValueError(f'Expected output file to be not yet opened, got {self.state}') + + if path is not None: + self.path = path + + if self.path is not None: + if self.fout is not None: + self.fout.close() + self.fout = open(self.path, "wb") + self.state = WriterState.EMPTY + + def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None: + self.open_output_file(path) + if self.state is not WriterState.EMPTY: raise ValueError(f'Expected output file to be empty, got {self.state}') self._write_packed(" None: if self.state is not WriterState.HEADER: raise ValueError(f'Expected output file to contain the header, got {self.state}') + assert self.fout is not None - self.fout.write(self.kv_data) + kv_data = bytearray() + + for key, val in self.kv_data.items(): + kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) + kv_data += self._pack_val(val.value, val.type, add_vtype=True) + + self.fout.write(kv_data) self.flush() self.state = WriterState.KV_DATA def write_ti_data_to_file(self) -> None: if self.state is not WriterState.KV_DATA: raise ValueError(f'Expected output file to contain KV data, got {self.state}') + assert self.fout is not None - self.fout.write(self.ti_data) + ti_data = bytearray() + offset_tensor = 0 + + for name, ti in self.tensors.items(): + ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False) + n_dims = len(ti.shape) + ti_data += self._pack("I", n_dims) + for i in range(n_dims): + ti_data += self._pack("Q", ti.shape[n_dims - 1 - i]) + ti_data += 
self._pack("I", ti.dtype) + ti_data += self._pack("Q", offset_tensor) + offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment) + + self.fout.write(ti_data) self.flush() self.state = WriterState.TI_DATA - def add_key(self, key: str) -> None: - self.add_val(key, GGUFValueType.STRING, add_vtype=False) + def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None: + if key in self.kv_data: + raise ValueError(f'Duplicated key name {key!r}') + + self.kv_data[key] = GGUFValue(value=val, type=vtype) def add_uint8(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT8) + self.add_key_value(key,val, GGUFValueType.UINT8) def add_int8(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT8) + self.add_key_value(key, val, GGUFValueType.INT8) def add_uint16(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT16) + self.add_key_value(key, val, GGUFValueType.UINT16) def add_int16(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT16) + self.add_key_value(key, val, GGUFValueType.INT16) def add_uint32(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT32) + self.add_key_value(key, val, GGUFValueType.UINT32) def add_int32(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT32) + self.add_key_value(key, val, GGUFValueType.INT32) def add_float32(self, key: str, val: float) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.FLOAT32) + self.add_key_value(key, val, GGUFValueType.FLOAT32) def add_uint64(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT64) + self.add_key_value(key, val, GGUFValueType.UINT64) def add_int64(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT64) + self.add_key_value(key, val, GGUFValueType.INT64) def add_float64(self, key: str, val: float) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.FLOAT64) + self.add_key_value(key, val, GGUFValueType.FLOAT64) def add_bool(self, key: str, val: bool) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.BOOL) + self.add_key_value(key, val, GGUFValueType.BOOL) def add_string(self, key: str, val: str) -> None: if not val: return - self.add_key(key) - self.add_val(val, GGUFValueType.STRING) + self.add_key_value(key, val, GGUFValueType.STRING) def add_array(self, key: str, val: Sequence[Any]) -> None: if not isinstance(val, Sequence): raise ValueError("Value must be a sequence for array type") - self.add_key(key) - self.add_val(val, GGUFValueType.ARRAY) - - def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None: - if vtype is None: - vtype = GGUFValueType.get_type(val) - - if add_vtype: - self.kv_data += self._pack("I", vtype) - self.kv_data_count += 1 - - pack_fmt = self._simple_value_packing.get(vtype) - if pack_fmt is not None: - self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) - elif vtype == GGUFValueType.STRING: - encoded_val = val.encode("utf-8") if isinstance(val, str) else val - self.kv_data += self._pack("Q", len(encoded_val)) - self.kv_data += encoded_val - elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val: - ltype = GGUFValueType.get_type(val[0]) - if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]): - raise ValueError("All items in a GGUF 
array should be of the same type") - self.kv_data += self._pack("I", ltype) - self.kv_data += self._pack("Q", len(val)) - for item in val: - self.add_val(item, add_vtype=False) - else: - raise ValueError("Invalid GGUF metadata value type or value") + self.add_key_value(key, val, GGUFValueType.ARRAY) @staticmethod def ggml_pad(x: int, n: int) -> int: @@ -200,16 +219,12 @@ class GGUFWriter: self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype, tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None, ) -> None: - if self.state is not WriterState.EMPTY: - raise ValueError(f'Expected output file to be empty, got {self.state}') + if self.state is not WriterState.NO_FILE: + raise ValueError(f'Expected output file to be not yet opened, got {self.state}') - if name in self.ti_names: - raise ValueError(f'Duplicated tensor name {name}') - self.ti_names.add(name) + if name in self.tensors: + raise ValueError(f'Duplicated tensor name {name!r}') - encoded_name = name.encode("utf-8") - self.ti_data += self._pack("Q", len(encoded_name)) - self.ti_data += encoded_name if raw_dtype is None: if tensor_dtype == np.float16: dtype = GGMLQuantizationType.F16 @@ -231,14 +246,8 @@ class GGUFWriter: dtype = raw_dtype if tensor_dtype == np.uint8: tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype) - n_dims = len(tensor_shape) - self.ti_data += self._pack("I", n_dims) - for i in range(n_dims): - self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i]) - self.ti_data += self._pack("I", dtype) - self.ti_data += self._pack("Q", self.offset_tensor) - self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) - self.ti_data_count += 1 + + self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes) def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, @@ -252,10 +261,10 @@ class GGUFWriter: self.temp_file = fp shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape - self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) + self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype) if self.temp_file is None: - self.tensors.append(tensor) + self.tensors[name].tensor = tensor return tensor.tofile(self.temp_file) @@ -267,8 +276,9 @@ class GGUFWriter: fp.write(bytes([0] * pad)) def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: - if self.state is not WriterState.TI_DATA: - raise ValueError(f'Expected output file to contain tensor info, got {self.state}') + if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS: + raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') + assert self.fout is not None if self.endianess == GGUFEndian.BIG: tensor.byteswap(inplace=True) @@ -276,50 +286,51 @@ class GGUFWriter: tensor.tofile(self.fout) self.write_padding(self.fout, tensor.nbytes) + self.state = WriterState.WEIGHTS + def write_tensors_to_file(self, *, progress: bool = False) -> None: self.write_ti_data_to_file() + assert self.fout is not None + self.write_padding(self.fout, self.fout.tell()) if self.temp_file is None: - self.tensors.reverse() # to pop from the "beginning" in constant time + bar = None if progress: from tqdm import tqdm - total_bytes = sum(t.nbytes for t in self.tensors) + total_bytes = sum(t.nbytes for t in self.tensors.values()) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) - while True: - try: 
- tensor = self.tensors.pop() - except IndexError: - break - tensor.tofile(self.fout) - bar.update(tensor.nbytes) - self.write_padding(self.fout, tensor.nbytes) - return - while True: - try: - tensor = self.tensors.pop() - except IndexError: - break - tensor.tofile(self.fout) - self.write_padding(self.fout, tensor.nbytes) - return + # relying on the fact that Python dicts preserve insertion order (since 3.7) + for ti in self.tensors.values(): + assert ti.tensor is not None # can only iterate once over the tensors + assert ti.tensor.nbytes == ti.nbytes + ti.tensor.tofile(self.fout) + if bar is not None: + bar.update(ti.nbytes) + self.write_padding(self.fout, ti.nbytes) + ti.tensor = None + else: + self.temp_file.seek(0) - self.temp_file.seek(0) + shutil.copyfileobj(self.temp_file, self.fout) + self.flush() + self.temp_file.close() - shutil.copyfileobj(self.temp_file, self.fout) - self.flush() - self.temp_file.close() + self.state = WriterState.WEIGHTS def flush(self) -> None: + assert self.fout is not None self.fout.flush() def close(self) -> None: - self.fout.close() + if self.fout is not None: + self.fout.close() + self.fout = None def add_architecture(self) -> None: self.add_string(Keys.General.ARCHITECTURE, self.arch) @@ -449,7 +460,7 @@ class GGUFWriter: def add_rope_scaling_factor(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value) - def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None: + def add_rope_scaling_attn_factors(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value) def add_rope_scaling_orig_ctx_len(self, value: int) -> None: @@ -571,5 +582,32 @@ class GGUFWriter: pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>' return struct.pack(f'{pack_prefix}{fmt}', value) + def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes: + kv_data = bytearray() + + if add_vtype: + kv_data += self._pack("I", vtype) + + pack_fmt = self._simple_value_packing.get(vtype) + if pack_fmt is not None: + kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) + elif vtype == GGUFValueType.STRING: + encoded_val = val.encode("utf-8") if isinstance(val, str) else val + kv_data += self._pack("Q", len(encoded_val)) + kv_data += encoded_val + elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val: + ltype = GGUFValueType.get_type(val[0]) + if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]): + raise ValueError("All items in a GGUF array should be of the same type") + kv_data += self._pack("I", ltype) + kv_data += self._pack("Q", len(val)) + for item in val: + kv_data += self._pack_val(item, ltype, add_vtype=False) + else: + raise ValueError("Invalid GGUF metadata value type or value") + + return kv_data + def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None: + assert self.fout is not None self.fout.write(self._pack(fmt, value, skip_pack_prefix)) diff --git a/gguf-py/scripts/gguf-new-metadata.py b/gguf-py/scripts/gguf-new-metadata.py index 21e91180c..c4b90d581 100755 --- a/gguf-py/scripts/gguf-new-metadata.py +++ b/gguf-py/scripts/gguf-new-metadata.py @@ -101,8 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new logger.debug(f'Copying {field.name}') if val.value is not None: - writer.add_key(field.name) - writer.add_val(val.value, val.type) + writer.add_key_value(field.name, val.value, val.type) if 
gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata: logger.debug('Adding chat template(s)') @@ -111,8 +110,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new for key, val in new_metadata.items(): logger.debug(f'Adding {key}: "{val.value}" {val.description}') - writer.add_key(key) - writer.add_val(val.value, val.type) + writer.add_key_value(key, val.value, val.type) total_bytes = 0 From 5795b941827fdec6c1662986de962badff456718 Mon Sep 17 00:00:00 2001 From: compilade Date: Sat, 8 Jun 2024 22:47:25 -0400 Subject: [PATCH 04/15] convert-hf : match model part name prefix and suffix (#7687) In #7075, to fix the conversion of (some) models using model-00001-of-00001.safetensors instead of model.safetensors for a single model part we simply used the same logic as the part count to get the part names. But this doesn't always work correctly, like when unusual additional model files like consolidated.safetensors in https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3 are present. This commit matching both the prefix and the suffix of the model part names should fix this problem without breaking any previously-supported upstream models. But according to report by @teleprint-me there is still some persistent problem, but shall do in the meantime. --- convert-hf-to-gguf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 0327712d7..b38f48edf 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -73,10 +73,10 @@ class Model: self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.use_temp_file = use_temp_file self.lazy = not eager - self.part_names = Model.get_model_part_names(self.dir_model, ".safetensors") + self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors") self.is_safetensors = len(self.part_names) > 0 if not self.is_safetensors: - self.part_names = Model.get_model_part_names(self.dir_model, ".bin") + self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.hparams = Model.load_hparams(self.dir_model) self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) @@ -335,10 +335,10 @@ class Model: self.gguf_writer.close() @staticmethod - def get_model_part_names(dir_model: Path, suffix: str) -> list[str]: + def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]: part_names: list[str] = [] for filename in os.listdir(dir_model): - if filename.endswith(suffix): + if filename.startswith(prefix) and filename.endswith(suffix): part_names.append(filename) part_names.sort() From 2decf57bc6e4a6b45176c3727d964a01161beecc Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Sun, 9 Jun 2024 06:39:25 +0000 Subject: [PATCH 05/15] convert-hf : set the model name based on cli arg, if present (#7693) `--model-name` argument was added a while ago but did not do anything. This commit fixes this issue and enables this feature. 
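Taken together with the GGUFWriter rework in #7827 above, conversion can now stage all metadata and tensors first and only pick an output file at write time. A minimal usage sketch of that flow (hypothetical model name, key values and tensor; assumes the in-tree gguf-py package and numpy are importable):

    import numpy as np
    import gguf

    # No output path yet: metadata and tensor info are only staged in dicts.
    writer = gguf.GGUFWriter(path=None, arch="llama")
    writer.add_name("my-model")   # now a single add_key_value() under the hood
    writer.add_block_count(32)
    writer.add_tensor("output.weight", np.zeros((8, 4), dtype=np.float32))

    # The file is opened and the header/KV/tensor-info data are packed only here.
    writer.write_header_to_file("my-model.gguf")
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()

With this patch applied, convert-hf-to-gguf.py passes args.model_name through to the Model constructor, and add_name() falls back to the model directory name only when no --model-name argument was given.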
--- convert-hf-to-gguf.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index b38f48edf..025405a2c 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -52,6 +52,7 @@ class Model: endianess: gguf.GGUFEndian use_temp_file: bool lazy: bool + model_name: str | None part_names: list[str] is_safetensors: bool hparams: dict[str, Any] @@ -64,7 +65,7 @@ class Model: # subclasses should define this! model_arch: gguf.MODEL_ARCH - def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool): + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") self.dir_model = dir_model @@ -73,6 +74,7 @@ class Model: self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.use_temp_file = use_temp_file self.lazy = not eager + self.model_name = model_name self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors") self.is_safetensors = len(self.part_names) > 0 if not self.is_safetensors: @@ -182,7 +184,7 @@ class Model: return new_name def set_gguf_parameters(self): - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_block_count(self.block_count) if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: @@ -665,7 +667,7 @@ class GPTNeoXModel(Model): def set_gguf_parameters(self): block_count = self.hparams["num_hidden_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) @@ -798,7 +800,7 @@ class MPTModel(Model): def set_gguf_parameters(self): block_count = self.hparams["n_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_block_count(block_count) @@ -850,7 +852,7 @@ class OrionModel(Model): raise ValueError("gguf: can not find ctx length parameter.") self.gguf_writer.add_file_type(self.ftype) - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_source_hf_repo(hf_repo) self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) @@ -887,7 +889,7 @@ class BaichuanModel(Model): else: raise ValueError("gguf: can not find ctx length parameter.") - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_source_hf_repo(hf_repo) self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) @@ -1010,7 +1012,7 @@ class XverseModel(Model): else: raise ValueError("gguf: can not find 
ctx length parameter.") - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_source_hf_repo(hf_repo) self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) @@ -1206,7 +1208,7 @@ class StableLMModel(Model): hparams = self.hparams block_count = hparams["num_hidden_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) @@ -1681,7 +1683,7 @@ class GPT2Model(Model): model_arch = gguf.MODEL_ARCH.GPT2 def set_gguf_parameters(self): - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_block_count(self.hparams["n_layer"]) self.gguf_writer.add_context_length(self.hparams["n_ctx"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) @@ -2248,7 +2250,7 @@ class GemmaModel(Model): hparams = self.hparams block_count = hparams["num_hidden_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) @@ -2348,7 +2350,7 @@ class MambaModel(Model): # Fail early for models which don't have a block expansion factor of 2 assert d_inner == 2 * d_model - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading @@ -2852,7 +2854,7 @@ def main() -> None: logger.error(f"Model {hparams['architectures'][0]} is not supported") sys.exit(1) - model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy) + model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy, args.model_name) logger.info("Set model parameters") model_instance.set_gguf_parameters() From 42b53d192f4e3abf1b7c8e424628424504ea5dc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 9 Jun 2024 09:42:25 +0200 Subject: [PATCH 06/15] CUDA: revise q8_1 data layout for mul_mat_q (#7824) --- ggml-cuda.cu | 88 +++++++++------ ggml-cuda/mmq.cu | 3 +- ggml-cuda/mmq.cuh | 236 ++++++++++++++++++++++------------------- ggml-cuda/quantize.cu | 89 ++++++++++++++-- ggml-cuda/quantize.cuh | 17 ++- 5 files changed, 282 insertions(+), 151 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dad8a9e2d..af10f21a0 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1347,10 +1347,30 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { GGML_UNUSED(main_device); } +static cudaError_t ggml_cuda_Memcpy2DPeerAsync( + void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { + 
+#if !defined(GGML_USE_HIPBLAS) + // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices + cudaMemcpy3DPeerParms p = {}; + p.dstDevice = dstDevice; + p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height); + p.srcDevice = srcDevice; + p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height); + p.extent = make_cudaExtent(width, height, 1); + return cudaMemcpy3DPeerAsync(&p, stream); +#else + // HIP does not support cudaMemcpy3DPeerAsync or vmm pools + GGML_UNUSED(dstDevice); + GGML_UNUSED(srcDevice); + return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream); +#endif // !defined(GGML_USE_HIPBLAS) +} + static void ggml_cuda_op_mul_mat( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, - const bool convert_src1_to_q8_1) { + quantize_cuda_t quantize_src1) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; @@ -1407,7 +1427,9 @@ static void ggml_cuda_op_mul_mat( } struct dev_data { - ggml_cuda_pool_alloc src0_dd_alloc; + int cc; + + ggml_cuda_pool_alloc src0_dd_alloc; ggml_cuda_pool_alloc src1_ddf_alloc; ggml_cuda_pool_alloc src1_ddq_alloc; ggml_cuda_pool_alloc dst_dd_alloc; @@ -1426,6 +1448,8 @@ static void ggml_cuda_op_mul_mat( int used_devices = 0; for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + dev[id].cc = ggml_cuda_info().devices[id].cc; + // by default, use all rows dev[id].row_low = 0; dev[id].row_high = ne01; @@ -1476,11 +1500,15 @@ static void ggml_cuda_op_mul_mat( dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1)); } - if (convert_src1_to_q8_1) { - dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); + if (quantize_src1) { + size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq); + } + dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size); if (src1_on_device && src1_is_contiguous) { - quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); + quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } } @@ -1526,7 +1554,12 @@ static void ggml_cuda_op_mul_mat( const int64_t i03 = i0 / ne12; const int64_t i02 = i0 % ne12; - const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; + size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs; + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq); + } else { + src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs; + } // for split tensors the data begins at i0 == i0_offset_low char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; @@ -1543,10 +1576,17 @@ static void ggml_cuda_op_mul_mat( // copy src0, src1 to device if necessary if (src1_is_contiguous) { if (id != ctx.device) { - if (convert_src1_to_q8_1) { + if (quantize_src1) { char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset; - CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, ctx.device, - src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream)); + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + const size_t pitch = 
ne11*sizeof(block_q8_1_mmq); + const size_t width = src1_ncols*sizeof(block_q8_1_mmq); + const size_t height = src1_padded_col_size/(4*QK8_1); + CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream)); + } else { + CUDA_CHECK(cudaMemcpyPeerAsync( + src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream)); + } } else { float * src1_ddf_i_source = (float *) src1->data; src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; @@ -1561,8 +1601,8 @@ static void ggml_cuda_op_mul_mat( GGML_ASSERT(false); } - if (convert_src1_to_q8_1 && !src1_is_contiguous) { - quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); + if (quantize_src1 && !src1_is_contiguous) { + quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } @@ -1587,22 +1627,8 @@ static void ggml_cuda_op_mul_mat( float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); dhf_dst_i += src1_col_0*ne0 + dev[id].row_low; -#if !defined(GGML_USE_HIPBLAS) - // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices - cudaMemcpy3DPeerParms p = {}; - p.dstDevice = ctx.device; - p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols); - p.srcDevice = id; - p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols); - p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1); - CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream)); -#else - // HIP does not support cudaMemcpy3DPeerAsync or vmm pools - CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), - dst_dd_i, row_diff*sizeof(float), - row_diff*sizeof(float), src1_ncols, - cudaMemcpyDeviceToDevice, stream)); -#endif + CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync( + dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream)); } else { float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); @@ -1941,13 +1967,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor // KQ + KQV multi-batch ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, true); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); } else { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); } } diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu index 58799e4ca..1d6b9e698 100644 --- a/ggml-cuda/mmq.cu +++ b/ggml-cuda/mmq.cu @@ -11,6 +11,7 @@ void ggml_cuda_op_mul_mat_q( const int64_t nb01 = src0->nb[1]; const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; GGML_ASSERT(ne10 % QK8_1 == 0); const int64_t ne0 = 
dst->ne[0]; @@ -25,7 +26,7 @@ void ggml_cuda_op_mul_mat_q( // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, nrows_dst}; + const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst}; switch (src0->type) { case GGML_TYPE_Q4_0: diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 6744cce6d..3ccae8a0c 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -1,15 +1,26 @@ +#pragma once + #include "common.cuh" #include "vecdotq.cuh" #include #include +#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) + typedef void (*load_tiles_mmq_t)( const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride); typedef void (*vec_dot_mmq_t)( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0); + const int * __restrict__ y, float * __restrict__ sum, const int & k0); + +struct block_q8_1_mmq { + half2 ds[4]; + int8_t qs[4*QK8_1]; +}; +static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); +static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); struct tile_x_sizes { int ql; @@ -132,10 +143,14 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const float * x_dmf = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -145,19 +160,18 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat( const int i = i0 + threadIdx.x; const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const float * x_dmf = (const float *) x_dm; int u[2*VDR_Q4_0_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE]; } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], + y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } @@ -203,10 +217,13 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * 
__restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -221,13 +238,13 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat( #pragma unroll for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE]; } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1], + y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } @@ -293,10 +310,14 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const float * x_dmf = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -306,20 +327,18 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat( const int i = i0 + threadIdx.x; const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; + const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0; int u[2*VDR_Q5_0_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE]; } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } @@ -383,10 +402,13 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + 
#pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -396,18 +418,18 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat( const int i = i0 + threadIdx.x; const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k0/QI5_1; + const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1; int u[2*VDR_Q5_1_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE]; } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } @@ -455,10 +477,14 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const float * x_dmf = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -467,12 +493,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat( for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[j * WARP_SIZE + k0], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], - y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1]); + (&x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], + y_df[j*MMQ_TILE_Y_K + k0/QI8_1]); } } } @@ -531,10 +554,13 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -545,11 +571,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat( const int kbx = k0 / QI2_K; const int ky = (k0 % QI2_K) * QR2_K; - const float * y_df = (const float *) y_ds; int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int kqsx = i*(WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); #pragma 
unroll @@ -557,11 +582,11 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat( v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; } - const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + const uint8_t * scales = ((const uint8_t *) &x_sc[i*(WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; - const int index_y = j * WARP_SIZE + (QR2_K*k0) % WARP_SIZE; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); + v, &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], scales, + x_dm[i*(WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]); } } } @@ -646,7 +671,11 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + + const float * x_dmf = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -658,8 +687,6 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat( const int kbx = k0 / QI3_K; const int ky = (k0 % QI3_K) * QR3_K; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; @@ -667,19 +694,19 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat( #pragma unroll for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int kqsx = i*(WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); const int shift = 2 * ((ky % 32) / 8); const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; - const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vh = x_qh[i*(WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); const int vlh = (vh << 2) & 0x04040404; v[l] = __vsubss4(vll, vlh); } - const int index_y = j * WARP_SIZE + (k0*QR3_K) % WARP_SIZE; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( - v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); + v, &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales, + x_dmf[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]); } } } @@ -746,10 +773,13 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -760,9 +790,9 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat( const uint8_t * 
sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8); - const int index_y = j * WARP_SIZE + (QR4_K*k0) % WARP_SIZE; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( - &x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); + &x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8, + x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]); } } } @@ -842,10 +872,13 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -856,10 +889,9 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat( const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); - const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k0; - const int index_y = j * WARP_SIZE + (QR5_K*k0) % WARP_SIZE; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( - &x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); + &x_ql[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8, + x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]); } } } @@ -932,10 +964,14 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); + const float * x_dmf = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -944,15 +980,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat( for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]); - const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k0; - const int index_y = j * WARP_SIZE + (QR6_K*k0) % WARP_SIZE; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( - &x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); + &x_ql[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc, + x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]); } } } @@ -964,7 +996,6 @@ struct mmq_type_traits; template struct mmq_type_traits { - static constexpr bool need_sum = true; static constexpr int 
vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat; @@ -972,7 +1003,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = true; static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat; @@ -980,7 +1010,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = false; static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat; @@ -988,7 +1017,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = true; static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat; @@ -996,7 +1024,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = false; static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat; @@ -1004,7 +1031,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = false; static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat; @@ -1012,7 +1038,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = false; static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat; @@ -1020,7 +1045,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = true; static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat; @@ -1028,7 +1052,6 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = true; static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat; @@ -1036,12 +1059,36 @@ struct mmq_type_traits { template struct mmq_type_traits { - static constexpr bool need_sum = false; static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat; }; +static int mmq_need_sum(const ggml_type type_x) { + switch (type_x) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return true; + case GGML_TYPE_Q5_0: + return false; + case GGML_TYPE_Q5_1: + return true; + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + return false; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return true; + case GGML_TYPE_Q6_K: + return false; + default: + GGML_ASSERT(false); + break; + } + return false; +} + template #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) @@ -1056,7 +1103,7 @@ template #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) static __global__ 
void mul_mat_q( const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, - const int ne00, const int ne01, const int stride00, const int ne10, const int ne11, const int ne0) { + const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) { // Skip unused template specializations for faster compilation: if (mmq_x > get_mmq_x_max_device()) { @@ -1068,7 +1115,6 @@ static __global__ void mul_mat_q( constexpr int qr = ggml_cuda_type_traits::qr; constexpr int qi = ggml_cuda_type_traits::qi; constexpr int mmq_y = get_mmq_y_device(mmq_x); - constexpr bool need_sum = mmq_type_traits::need_sum; constexpr int vdr = mmq_type_traits::vdr; constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot; @@ -1080,62 +1126,38 @@ static __global__ void mul_mat_q( half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql); int * tile_x_qh = (int *) (tile_x_dm + txs.dm); int * tile_x_sc = (int *) (tile_x_qh + txs.qh); - int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE] - half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1]; - - const block_q8_1 * y = (const block_q8_1 *) yc; + int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)] const int blocks_per_row_x = ne00 / qk; - const int blocks_per_col_y = ne10 / QK8_1; const int blocks_per_warp = WARP_SIZE / qi; const int & ne1 = ne11; const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1; + const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); + float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f}; for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) { - load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride00*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride00); + load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01); #pragma unroll for (int kr = 0; kr < qr; ++kr) { - const int kqs = kr*WARP_SIZE + threadIdx.x; - const int kbxd = kqs / QI8_1; - + const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll - for (int i0 = 0; i0 < mmq_x; i0 += nwarps) { - const int i = min(blockIdx.y*mmq_x + threadIdx.y + i0, ne11-1); // to prevent out-of-bounds memory accesses + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { + int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; - const block_q8_1 * by0 = &y[i*blocks_per_col_y + kb0 * (qk/QK8_1) + kbxd]; - - const int index_y = (i0 + threadIdx.y) * WARP_SIZE + kqs % WARP_SIZE; - tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); - } - -#pragma unroll - for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { - const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; - const int kby = threadIdx.x % (WARP_SIZE/QI8_1); - const int i_y_eff = min(blockIdx.y*mmq_x + ids, ne11-1); - - // if the sum is not needed it's faster to transform the scale to f32 ahead of time - const half2 * dsi_src = &y[i_y_eff*blocks_per_col_y + kb0 * (qk/QK8_1) + kr*(WARP_SIZE/QI8_1) + kby].ds; - half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; - if (need_sum) { - *dsi_dst = *dsi_src; - } else { - float * dfi_dst = (float *) dsi_dst; - *dfi_dst = __low2float(*dsi_src); - } + tile_y[l] = by0[l]; } __syncthreads(); // #pragma unroll // unrolling 
this loop causes too much register pressure for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) { - vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, sum, k0); + vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y, sum, k0); } __syncthreads(); @@ -1165,8 +1187,8 @@ static __global__ void mul_mat_q( struct mmq_args { const char * x; const char * y; float * dst; - int64_t ne00; int64_t ne01; int64_t stride00; - int64_t ne10; int64_t ne11; + int64_t ne00; int64_t ne01; int64_t stride01; + int64_t ne10; int64_t ne11; int64_t stride11; int64_t ne0; }; @@ -1184,7 +1206,7 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) { const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y); const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int); const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2); - const int shmem = shmem_x + shmem_y; + const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int)); #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1198,11 +1220,11 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) { if (args.ne01 % mmq_y == 0) { const bool need_check = false; mul_mat_q<<>> - (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0); + (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); } else { const bool need_check = true; mul_mat_q<<>> - (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0); + (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); } } diff --git a/ggml-cuda/quantize.cu b/ggml-cuda/quantize.cu index 7578c4b6c..b46786822 100644 --- a/ggml-cuda/quantize.cu +++ b/ggml-cuda/quantize.cu @@ -1,22 +1,23 @@ #include "quantize.cuh" +#include -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) { - const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) { + const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; - if (ix >= kx_padded) { + if (ix0 >= kx0_padded) { return; } - const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y; + const int64_t ix1 = blockIdx.y; - const int64_t i_padded = (int64_t)iy*kx_padded + ix; + const int64_t i_padded = ix1*kx0_padded + ix0; block_q8_1 * y = (block_q8_1 *) vy; const int64_t ib = i_padded / QK8_1; // block index const int64_t iqs = i_padded % QK8_1; // quant index - const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; + const float xi = ix0 < kx ? 
x[ix1*kx + ix0] : 0.0f; float amax = fabsf(xi); float sum = xi; @@ -36,10 +37,76 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } -void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) { - const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - const dim3 num_blocks(block_num_x, ky, 1); - const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, kx, kx_padded); +template +static __global__ void quantize_mmq_q8_1( + const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { + + const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (ix0 >= kx0_padded) { + return; + } + + const int64_t ix1 = kx1*blockIdx.z + blockIdx.y; + + block_q8_1_mmq * y = (block_q8_1_mmq *) vy; + + const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel + const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel + const int64_t iqs = ix0 % (4*QK8_1); // quant index in block + + const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f; + float amax = fabsf(xi); + + amax = warp_reduce_max(amax); + + float sum; + if (need_sum) { + sum = warp_reduce_sum(xi); + } + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs % QK8_1 != 0) { + return; + } + + if (need_sum) { + y[ib].ds[iqs/QK8_1] = make_half2(d, sum); + } else { + ((float *) y[ib].ds)[iqs/QK8_1] = d; + } } +void quantize_row_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, + const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) { + + GGML_ASSERT(kx0_padded % QK8_1 == 0); + + const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, kx1*channels, 1); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); + quantize_q8_1<<>>(x, vy, kx0, kx0_padded); + + GGML_UNUSED(type_x); +} + +void quantize_mmq_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, + const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) { + + GGML_ASSERT(kx0_padded % (4*QK8_1) == 0); + + const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, kx1, channels); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); + if (mmq_need_sum(type_x)) { + quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); + } else { + quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); + } +} diff --git a/ggml-cuda/quantize.cuh b/ggml-cuda/quantize.cuh index b37a4752f..486c9360a 100644 --- a/ggml-cuda/quantize.cuh +++ b/ggml-cuda/quantize.cuh @@ -1,5 +1,20 @@ +#pragma once + #include "common.cuh" +#include "mmq.cuh" + +#include #define CUDA_QUANTIZE_BLOCK_SIZE 256 -void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream); +typedef void (*quantize_cuda_t)( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, + const ggml_type type_x, cudaStream_t stream); + +void quantize_row_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, 
const int64_t channels, const int64_t kx0_padded, + const ggml_type type_x, cudaStream_t stream); + +void quantize_mmq_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, + const ggml_type type_x, cudaStream_t stream); From 3e2ee443159724e2d3a0741f6b167e599ec088aa Mon Sep 17 00:00:00 2001 From: mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com> Date: Sun, 9 Jun 2024 12:50:35 +0200 Subject: [PATCH 07/15] server: do not remove whitespace at the start of a completion chunk (#7830) --- examples/server/public/index-new.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/public/index-new.html b/examples/server/public/index-new.html index d571c2779..19c9f643d 100644 --- a/examples/server/public/index-new.html +++ b/examples/server/public/index-new.html @@ -416,7 +416,7 @@ message = html`<${Probabilities} data=${data} />` } else { const text = isArrayMessage ? - data.map(msg => msg.content).join('').replace(/^\s+/, '') : + data.map(msg => msg.content).join('') : data; message = isCompletionMode ? text : From 57bf62ce7cb75cca589943e2050d29bff4026e76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=A1s=20P=C3=A9rez?= Date: Sun, 9 Jun 2024 11:24:29 -0400 Subject: [PATCH 08/15] docs: Added initial PR template with directions for doc only changes and squash merges [no ci] (#7700) This commit adds pull_request_template.md and CONTRIBUTING.md . It focuses on explaining to contributors the need to rate PR complexity level, when to add [no ci] and how to format PR title and descriptions. Co-authored-by: Brian Co-authored-by: compilade --- .../PULL_REQUEST_TEMPLATE/pull_request_template.md | 5 +++++ CONTRIBUTING.md | 14 ++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE/pull_request_template.md create mode 100644 CONTRIBUTING.md diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md new file mode 100644 index 000000000..0852fded5 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md @@ -0,0 +1,5 @@ +- Self Reported Review Complexity: + - [ ] Review Complexity : Low + - [ ] Review Complexity : Medium + - [ ] Review Complexity : High +- [ ] I have read the [contributing guidelines](CONTRIBUTING.md) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..991d85e49 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,14 @@ +# Contributing Guidelines + +## Checklist + +* Make sure your PR follows the [coding guidelines](https://github.com/ggerganov/llama.cpp/blob/master/README.md#coding-guidelines) +* Test your changes using the commands in the [`tests`](tests) folder. For instance, running the `./tests/test-backend-ops` command tests different backend implementations of the GGML library +* Execute [the full CI locally on your machine](ci/README.md) before publishing + +## PR formatting + +* Please rate the complexity of your PR (i.e. `Review Complexity : Low`, `Review Complexity : Medium`, `Review Complexity : High`). This makes it easier for maintainers to triage the PRs. + - The PR template has a series of review complexity checkboxes `[ ]` that you can mark as `[X]` for your conveience. Refer to [About task lists](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/about-task-lists) for more information. 
+* If the pull request only contains documentation changes (e.g., updating READMEs, adding new wiki pages), please add `[no ci]` to the commit title. This will skip unnecessary CI checks and help reduce build times. +* When squashing multiple commits on merge, use the following format for your commit title: ` : (#)`. For example: `utils : Fix typo in utils.py (#1234)` From e95beeb1fc4621826ddd616776dbdf717366bf5c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jun 2024 20:19:35 +0300 Subject: [PATCH 09/15] imatrix : handle partial entries (#7833) --- examples/imatrix/imatrix.cpp | 58 +++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index e18f49563..574f5ed9c 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -218,20 +218,64 @@ void IMatrixCollector::save_imatrix(int ncall) const { fname += std::to_string(ncall); } + // avoid writing imatrix entries that do not have full data + // this can happen with MoE models where some of the experts end up not being exercised by the provided training data + + int n_entries = 0; + std::vector to_store; + + bool is_first = true; // for printing + for (const auto & kv : m_stats) { + const int n_all = kv.second.counts.size(); + + if (n_all == 0) { + continue; + } + + int n_zeros = 0; + for (const int c : kv.second.counts) { + if (c == 0) { + n_zeros++; + } + } + + if (n_zeros != 0 && is_first) { + fprintf(stderr, "\n"); + is_first = false; + } + + if (n_zeros == n_all) { + fprintf(stderr, "%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str()); + continue; + } + + if (n_zeros > 0) { + fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all); + continue; + } + + n_entries++; + to_store.push_back(kv.first); + } + + if (to_store.size() < m_stats.size()) { + fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size()); + } + std::ofstream out(fname, std::ios::binary); - int n_entries = m_stats.size(); out.write((const char *) &n_entries, sizeof(n_entries)); - for (const auto & p : m_stats) { - int len = p.first.size(); + for (const auto & name : to_store) { + const auto & stat = m_stats.at(name); + int len = name.size(); out.write((const char *) &len, sizeof(len)); - out.write(p.first.c_str(), len); - out.write((const char *) &p.second.ncall, sizeof(p.second.ncall)); - int nval = p.second.values.size(); + out.write(name.c_str(), len); + out.write((const char *) &stat.ncall, sizeof(stat.ncall)); + int nval = stat.values.size(); out.write((const char *) &nval, sizeof(nval)); if (nval > 0) { std::vector tmp(nval); for (int i = 0; i < nval; i++) { - tmp[i] = (p.second.values[i] / static_cast(p.second.counts[i])) * static_cast(p.second.ncall); + tmp[i] = (stat.values[i] / static_cast(stat.counts[i])) * static_cast(stat.ncall); } out.write((const char*)tmp.data(), nval*sizeof(float)); } From 10ceba354a3b152ff425e9fa97f9caaef99a46b1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 10 Jun 2024 02:04:50 +0300 Subject: [PATCH 10/15] flake.lock: Update (#7838) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flake lock file updates: • Updated input 'nixpkgs': 'github:NixOS/nixpkgs/ad57eef4ef0659193044870c731987a6df5cf56b?narHash=sha256-SzDKxseEcHR5KzPXLwsemyTR/kaM9whxeiJohbL04rs%3D' (2024-05-29) → 
'github:NixOS/nixpkgs/051f920625ab5aabe37c920346e3e69d7d34400e?narHash=sha256-4q0s6m0GUcN7q%2BY2DqD27iLvbcd1G50T2lv08kKxkSI%3D' (2024-06-07) Co-authored-by: github-actions[bot] --- flake.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index 09047ab10..7272e65fa 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1716948383, - "narHash": "sha256-SzDKxseEcHR5KzPXLwsemyTR/kaM9whxeiJohbL04rs=", + "lastModified": 1717786204, + "narHash": "sha256-4q0s6m0GUcN7q+Y2DqD27iLvbcd1G50T2lv08kKxkSI=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "ad57eef4ef0659193044870c731987a6df5cf56b", + "rev": "051f920625ab5aabe37c920346e3e69d7d34400e", "type": "github" }, "original": { From af4ae502ddaeb03cd5861273ca2e9a5ae4551db7 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 10 Jun 2024 02:21:31 -0700 Subject: [PATCH 11/15] use the correct SYCL context for host USM allocations (#7777) Signed-off-by: Ben Ashbaugh --- ggml-sycl.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 0a645b2e1..42fc0df20 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -13089,10 +13089,12 @@ void *ggml_sycl_host_malloc(size_t size) try { return nullptr; } + ggml_sycl_set_device(g_main_device); + dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; + void * ptr = nullptr; - //allow to use dpct::get_in_order_queue() for host malloc dpct::err0 err = CHECK_TRY_ERROR( - ptr = (void *)sycl::malloc_host(size, dpct::get_in_order_queue())); + ptr = (void *)sycl::malloc_host(size, *main_stream)); if (err != 0) { // clear the error @@ -13113,8 +13115,9 @@ catch (sycl::exception const &exc) { } void ggml_sycl_host_free(void *ptr) try { - //allow to use dpct::get_in_order_queue() for host malloc - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue()))); + ggml_sycl_set_device(g_main_device); + dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *main_stream))); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ From 1f0dabda8d5c131f9d4632aa41de74317cdd61fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 10 Jun 2024 11:45:13 +0200 Subject: [PATCH 12/15] CUDA: use tensor cores for MMQ (#7676) * CUDA: int8 tensor cores for MMQ (legacy quants) * fix out-of-bounds writes * __builtin_assume -> GGML_CUDA_ASSUME * fix writeback returning too early --- ggml-cuda/common.cuh | 21 +- ggml-cuda/fattn-common.cuh | 20 +- ggml-cuda/fattn-tile-f16.cu | 2 +- ggml-cuda/fattn-vec-f16.cuh | 2 +- ggml-cuda/fattn-wmma-f16.cuh | 6 +- ggml-cuda/mma.cuh | 95 ++++++++ ggml-cuda/mmq.cuh | 459 ++++++++++++++++++++++++++++++++--- 7 files changed, 550 insertions(+), 55 deletions(-) create mode 100644 ggml-cuda/mma.cuh diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 90a0a81ea..7f4764d60 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -139,6 +139,7 @@ #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define CC_VOLTA 700 +#define CC_TURING 750 #define CC_AMPERE 800 #define CC_OFFSET_AMD 1000000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) @@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000 #endif // 
defined(GGML_USE_HIPBLAS) -#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL +#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL +#define FP16_AVAILABLE +#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL -#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#define FP16_MMA_AVAILABLE +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA + +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING +#define INT8_MMA_AVAILABLE +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING static bool fast_fp16_available(const int cc) { return cc >= CC_PASCAL && cc != 610; @@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_VOLTA; } +static bool int8_mma_available(const int cc) { + return cc < CC_OFFSET_AMD && cc >= CC_TURING; +} + [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { @@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { } static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #pragma unroll @@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { } static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX return __float2half(fmaxf(__half2float(a), __half2float(b))); diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh index c00f8606a..37b3b9932 100644 --- a/ggml-cuda/fattn-common.cuh +++ b/ggml-cuda/fattn-common.cuh @@ -74,7 +74,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const int sumi = __dp4a(v, u, 0); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; @@ -122,7 +122,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const int sumi = __dp4a(v, u, 0); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; @@ -181,7 +181,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const int sumi = __dp4a(v, u, 0); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; @@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const int sumi = __dp4a(v, u, 0); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; @@ -314,7 +314,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { const half2 * Q_h2 = (const half2 *) Q_v; @@ -407,7 +407,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ const int q0 = x[ib].qs[iqs]; const int q = ((q0 >> (4*shift)) & 
0x0F) - 8; -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { return ((half) d)*((half) q); } @@ -428,7 +428,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ const int q0 = x[ib].qs[iqs]; const int q = ((q0 >> (4*shift)) & 0x0F); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { return __low2half(dm)*((half) q) + __high2half(dm); } @@ -453,7 +453,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ const int qh = ((qh0 >> idq) << 4) & 0x10; const int q = (ql | qh) - 16; -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { return ((half) d)*((half) q); } @@ -478,7 +478,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ const int qh = ((qh0 >> idq) << 4) & 0x10; const int q = (ql | qh); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { return __low2half(dm)*((half) q) + __high2half(dm); } @@ -497,7 +497,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ const T d = x[ib].d; const int q = x[ib].qs[iqs]; -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { return ((half) d)*((half) q); } diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index cb11d7212..c6c35134d 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -43,7 +43,7 @@ static __global__ void flash_attn_tile_ext_f16( const int ne1, const int ne2, const int ne3) { -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. diff --git a/ggml-cuda/fattn-vec-f16.cuh b/ggml-cuda/fattn-vec-f16.cuh index 9e1aa2c6b..02a4ad072 100644 --- a/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml-cuda/fattn-vec-f16.cuh @@ -40,7 +40,7 @@ static __global__ void flash_attn_vec_ext_f16( const int ne1, const int ne2, const int ne3) { -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); diff --git a/ggml-cuda/fattn-wmma-f16.cuh b/ggml-cuda/fattn-wmma-f16.cuh index 59cd30d78..ae2322242 100644 --- a/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml-cuda/fattn-wmma-f16.cuh @@ -1,9 +1,9 @@ #include "common.cuh" #include "fattn-common.cuh" -#if FP16_MMA_AVAILABLE +#ifdef FP16_MMA_AVAILABLE #include -#endif +#endif // FP16_MMA_AVAILABLE // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template @@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if FP16_MMA_AVAILABLE +#ifdef FP16_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. 
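    Note: the FP16_AVAILABLE / FP16_MMA_AVAILABLE / INT8_MMA_AVAILABLE changes above turn expression-style macros into plain defined/undefined feature flags, which is why every call site switches from `#if` to `#ifdef`. With the old form `#define FP16_AVAILABLE (cond)`, an `#ifdef FP16_AVAILABLE` check would always be true regardless of the condition. The sketch below is only an illustration of that pattern and is not part of the patch; the macro and helper names (EXAMPLE_FP16_AVAILABLE, example_add) are hypothetical stand-ins, with 600 playing the role of CC_PASCAL.

    ```cpp
    #include <cuda_fp16.h>

    // Feature flag defined only when the target architecture supports it,
    // mirroring the defined/undefined style the patch adopts.
    #if __CUDA_ARCH__ >= 600
    #define EXAMPLE_FP16_AVAILABLE
    #endif // __CUDA_ARCH__ >= 600

    // Illustrative device helper: with a defined/undefined flag, #ifdef selects
    // the fast fp16 path only on capable hardware; an expression-style macro
    // would have made this #ifdef unconditionally true.
    static __device__ __forceinline__ half example_add(const half a, const half b) {
    #ifdef EXAMPLE_FP16_AVAILABLE
        return __hadd(a, b);                                     // native half add
    #else
        return __float2half(__half2float(a) + __half2float(b));  // float fallback
    #endif // EXAMPLE_FP16_AVAILABLE
    }
    ```

    The same pattern is what lets mmq.cuh further down pick between the new int8 MMA kernels and the DP4A fallback with a single `#ifdef INT8_MMA_AVAILABLE`.
    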
diff --git a/ggml-cuda/mma.cuh b/ggml-cuda/mma.cuh new file mode 100644 index 000000000..71e8e3429 --- /dev/null +++ b/ggml-cuda/mma.cuh @@ -0,0 +1,95 @@ +#include "common.cuh" + +struct mma_int_A_I16K8 { + static constexpr int I = 16; + static constexpr int K = 8; + static constexpr int ne = 4; + + int x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + const int ret = (l%2) * (I/2) + threadIdx.x / (K/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < I); + return ret; + } + + static __device__ __forceinline__ int get_k(const int l) { + const int ret = (l/2) * (K/2) + threadIdx.x % (K/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < K); + return ret; + } +}; + +struct mma_int_B_J8K8 { + static constexpr int J = 8; + static constexpr int K = 8; + static constexpr int ne = 2; + + int x[ne] = {0}; + + static __device__ __forceinline__ int get_j(const int /* l */) { + const int ret = threadIdx.x / (K/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < J); + return ret; + } + + static __device__ __forceinline__ int get_k(const int l) { + const int ret = l * (K/2) + threadIdx.x % (K/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < K); + return ret; + } +}; + +struct mma_int_C_I16J8 { + static constexpr int I = 16; + static constexpr int J = 8; + static constexpr int ne = 4; + + int x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < I); + return ret; + } + + static __device__ __forceinline__ int get_j(const int l) { + const int ret = 2 * (threadIdx.x % (J/2)) + l%2; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < J); + return ret; + } + + __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) { +#ifdef INT8_MMA_AVAILABLE +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) + : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1])); +#else + // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead: + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(x[0]), "+r"(x[1]) + : "r"(mma_A.x[0]), "r"(mma_B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(x[2]), "+r"(x[3]) + : "r"(mma_A.x[1]), "r"(mma_B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(x[0]), "+r"(x[1]) + : "r"(mma_A.x[2]), "r"(mma_B.x[1])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(x[2]), "+r"(x[3]) + : "r"(mma_A.x[3]), "r"(mma_B.x[1])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(mma_A); + GGML_UNUSED(mma_B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } +}; diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 3ccae8a0c..62111f376 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -2,6 +2,7 @@ #include "common.cuh" #include "vecdotq.cuh" +#include "mma.cuh" #include #include @@ -14,6 +15,7 @@ typedef void (*load_tiles_mmq_t)( typedef void (*vec_dot_mmq_t)( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0); 
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1); struct block_q8_1_mmq { half2 ds[4]; @@ -141,15 +143,15 @@ template static __device__ __forceinlin } template -static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat( +static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - const float * x_dmf = (const float *) x_dm; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; + const float * x_df = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -170,12 +172,76 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat( } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl - (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], + (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } +template +static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + typedef mma_int_A_I16K8 mma_A; + typedef mma_int_B_J8K8 mma_B; + typedef mma_int_C_I16J8 mma_C; + + const float * x_df = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + mma_A A; + float dA[mma_C::ne/2]; + + const int i0 = threadIdx.y*mma_A::I; + static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + mma_A::get_i(l); + const int k = k0 + mma_A::get_k(l) % QI4_0; + const int shift = 4*(mma_A::get_k(l) / QI4_0); + + A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808); + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l); + + dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0]; + } + + for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { + mma_C C; + mma_B B; + half2 dsB[mma_C::ne/2]; + +#pragma unroll + for (int l = 0; l < mma_B::ne; ++l) { + const int j = j0 + mma_B::get_j(l); + const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; + + B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; + } + + C.mma_K8(A, B); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l]; + } + } +} + template static __device__ __forceinline__ void load_tiles_q4_1( const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { @@ -215,7 +281,7 @@ template static __device__ __forceinlin } template -static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat( +static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const 
int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { @@ -249,6 +315,70 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat( } } +template +static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + typedef mma_int_A_I16K8 mma_A; + typedef mma_int_B_J8K8 mma_B; + typedef mma_int_C_I16J8 mma_C; + + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + mma_A A; + half2 dmA[mma_C::ne/2]; + + const int i0 = threadIdx.y*mma_A::I; + static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + mma_A::get_i(l); + const int k = k0 + mma_A::get_k(l) % QI4_0; + const int shift = 4*(mma_A::get_k(l) / QI4_0); + + A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F; + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l); + + dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0]; + } + + for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { + mma_C C; + mma_B B; + half2 dsB[mma_C::ne/2]; + +#pragma unroll + for (int l = 0; l < mma_B::ne; ++l) { + const int j = j0 + mma_B::get_j(l); + const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; + + B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; + } + + C.mma_K8(A, B); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + const half2 dmA_dsB = dmA[l/2]*dsB[l%2]; + sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + } + } +} + template static __device__ __forceinline__ void load_tiles_q5_0( const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { @@ -308,7 +438,7 @@ template static __device__ __forceinlin } template -static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat( +static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { @@ -343,6 +473,68 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat( } } +template +static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + typedef mma_int_A_I16K8 mma_A; + typedef mma_int_B_J8K8 mma_B; + typedef mma_int_C_I16J8 mma_C; + + const float * x_df = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + mma_A A; + float dA[mma_C::ne/2]; + + const int i0 = threadIdx.y*mma_A::I; + static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + 
const int i = i0 + mma_A::get_i(l); + const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0; + + A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k]; + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l); + + dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0]; + } + + for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { + mma_C C; + mma_B B; + float dB[mma_C::ne/2]; + +#pragma unroll + for (int l = 0; l < mma_B::ne; ++l) { + const int j = j0 + mma_B::get_j(l); + const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; + + B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; + } + + C.mma_K8(A, B); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l]; + } + } +} template static __device__ __forceinline__ void load_tiles_q5_1( const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, @@ -400,7 +592,7 @@ template static __device__ __forceinlin } template -static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat( +static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { @@ -434,6 +626,69 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat( } } +template +static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + typedef mma_int_A_I16K8 mma_A; + typedef mma_int_B_J8K8 mma_B; + typedef mma_int_C_I16J8 mma_C; + + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + mma_A A; + half2 dmA[mma_C::ne/2]; + + const int i0 = threadIdx.y*mma_A::I; + static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + mma_A::get_i(l); + const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1; + + A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k]; + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l); + + dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1]; + } + + for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { + mma_C C; + mma_B B; + half2 dsB[mma_C::ne/2]; + +#pragma unroll + for (int l = 0; l < mma_B::ne; ++l) { + const int j = j0 + mma_B::get_j(l); + const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; + + B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; + } + + C.mma_K8(A, B); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + const half2 dmA_dsB = dmA[l/2]*dsB[l%2]; + sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + } + } +} + template static __device__ __forceinline__ void load_tiles_q8_0( const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & kbx0, const int & i_max, 
const int & stride) { @@ -475,7 +730,7 @@ template static __device__ __forceinlin } template -static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat( +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { @@ -500,6 +755,69 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat( } } +template +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + typedef mma_int_A_I16K8 mma_A; + typedef mma_int_B_J8K8 mma_B; + typedef mma_int_C_I16J8 mma_C; + + const float * x_df = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + mma_A A; + float dA[mma_C::ne/2]; + + const int i0 = threadIdx.y*mma_A::I; + static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + +#pragma unroll + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + mma_A::get_i(l); + const int k = k0 + mma_A::get_k(l); + + A.x[l] = x_ql[i*(WARP_SIZE + 1) + k]; + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l); + + dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0]; + } + + for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { + mma_C C; + mma_B B; + float dB[mma_C::ne/2]; + +#pragma unroll + for (int l = 0; l < mma_B::ne; ++l) { + const int j = j0 + mma_B::get_j(l); + const int k = k0 + mma_B::get_k(l); + + B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1]; + } + + C.mma_K8(A, B); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2]; + } + } +} + template static __device__ __forceinline__ void load_tiles_q2_K( const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { @@ -989,6 +1307,57 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat( } } +template +static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) { +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = blockIdx.y*mmq_x + j0 + threadIdx.y; + + if (j >= ne1) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = blockIdx.x*mmq_y + i0 + threadIdx.x; + + if (need_check && i >= ne0) { + continue; + } + + dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + } + } +} + +template +static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) { + typedef mma_int_C_I16J8 mma_C; + + const int i0 = threadIdx.y*mma_C::I; + static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l); + + if (j >= ne1) { + 
continue; + } + + const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l); + + if (need_check && i >= ne0) { + continue; + } + + dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l]; + } + } +} + // ------------------------------------------------------------------------------------------------------------------------------------- template @@ -998,35 +1367,65 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat; +#ifdef INT8_MMA_AVAILABLE + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mma; + static constexpr mmq_write_back_t write_back = mmq_write_back_mma; +#else + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a; + static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; +#endif // INT8_MMA_AVAILABLE }; template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat; +#ifdef INT8_MMA_AVAILABLE + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mma; + static constexpr mmq_write_back_t write_back = mmq_write_back_mma; +#else + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a; + static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; +#endif // INT8_MMA_AVAILABLE }; template struct mmq_type_traits { static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat; +#ifdef INT8_MMA_AVAILABLE + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mma; + static constexpr mmq_write_back_t write_back = mmq_write_back_mma; +#else + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a; + static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; +#endif // INT8_MMA_AVAILABLE }; template struct mmq_type_traits { static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat; +#ifdef INT8_MMA_AVAILABLE + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mma; + static constexpr mmq_write_back_t write_back = mmq_write_back_mma; +#else + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a; + static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; +#endif // INT8_MMA_AVAILABLE }; template struct mmq_type_traits { static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat; +#ifdef INT8_MMA_AVAILABLE + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mma; + static constexpr mmq_write_back_t write_back = mmq_write_back_mma; +#else + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a; + static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; +#endif // INT8_MMA_AVAILABLE }; template @@ -1034,6 +1433,7 @@ struct mmq_type_traits { static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat; + static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; }; template @@ -1041,6 +1441,7 @@ struct mmq_type_traits { static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; static constexpr 
load_tiles_mmq_t load_tiles = load_tiles_q3_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat; + static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; }; template @@ -1048,6 +1449,7 @@ struct mmq_type_traits { static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat; + static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; }; template @@ -1055,6 +1457,7 @@ struct mmq_type_traits { static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat; + static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; }; template @@ -1062,6 +1465,7 @@ struct mmq_type_traits { static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat; + static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; }; static int mmq_need_sum(const ggml_type type_x) { @@ -1118,6 +1522,7 @@ static __global__ void mul_mat_q( constexpr int vdr = mmq_type_traits::vdr; constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot; + constexpr mmq_write_back_t write_back = mmq_type_traits::write_back; constexpr tile_x_sizes txs = get_tile_x_sizes_device(type); @@ -1137,7 +1542,7 @@ static __global__ void mul_mat_q( const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); - float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f}; + float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) { @@ -1164,25 +1569,7 @@ static __global__ void mul_mat_q( } } -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = blockIdx.y*mmq_x + j0 + threadIdx.y; - - if (j >= ne1) { - return; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = blockIdx.x*mmq_y + i0 + threadIdx.x; - - if (need_check && i >= ne0) { - continue; - } - - dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; - } - } + write_back(sum, dst, ne0, ne1); } struct mmq_args { @@ -1256,10 +1643,10 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) { launch_mul_mat_q(args, stream); break; case 16: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 24: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 32: launch_mul_mat_q(args, stream); From d9da0e4986f121c727bdd9579a6688097b11602c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 10 Jun 2024 14:59:55 +0300 Subject: [PATCH 13/15] server : improve "prompt" handling (#7847) --- examples/server/server.cpp | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6ffaa8d9f..80714fa58 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -147,7 +147,7 @@ struct server_slot { int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_processed = 0; - json prompt; + std::string prompt; // when a task is submitted, we first tokenize the prompt and store it here std::vector prompt_tokens; @@ -822,13 +822,8 @@ struct server_context { continue; } - // skip the slot if it does not contains prompt - if 
(!slot.prompt.is_string()) { - continue; - } - // current slot's prompt - std::string slot_prompt = slot.prompt.get(); + std::string slot_prompt = slot.prompt; // length of the current slot's prompt int slot_prompt_len = slot_prompt.size(); @@ -958,13 +953,16 @@ struct server_context { if (!task.infill) { const auto & prompt = data.find("prompt"); if (prompt == data.end()) { - send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST); + send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST); return false; - } else { - slot.prompt = *prompt; } - if (slot.prompt.is_array() && slot.prompt.size() == 0) { - send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST); + + if (prompt->is_string()) { + slot.prompt = prompt->get(); + } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) { + slot.prompt = prompt->at(0).get(); + } else { + send_error(task, "\"prompt\" must be a string or an array of strings", ERROR_TYPE_INVALID_REQUEST); return false; } } @@ -1582,14 +1580,18 @@ struct server_context { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: { - int id_slot = json_value(task.data, "id_slot", -1); - std::string prompt = json_value(task.data, "prompt", std::string()); + const int id_slot = json_value(task.data, "id_slot", -1); server_slot * slot; if (id_slot != -1) { slot = get_slot_by_id(id_slot); } else { + std::string prompt; + if (task.data.contains("prompt") && task.data.at("prompt").is_string()) { + json_value(task.data, "prompt", std::string()); + } + slot = get_available_slot(prompt); } From c28a83902cf6ab6a9e085ad6d4cc2e95c4ccfe40 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 10 Jun 2024 15:00:15 +0300 Subject: [PATCH 14/15] examples : remove --instruct remnants (#7846) --- README.md | 29 ----------------------------- examples/alpaca.sh | 19 ------------------- examples/gpt4all.sh | 15 --------------- examples/llama2-13b.sh | 18 ------------------ examples/llama2.sh | 18 ------------------ 5 files changed, 99 deletions(-) delete mode 100755 examples/alpaca.sh delete mode 100755 examples/gpt4all.sh delete mode 100755 examples/llama2-13b.sh delete mode 100755 examples/llama2.sh diff --git a/README.md b/README.md index 09e8cad31..ecb9d00db 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,6 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
  • Quantization
  • Interactive mode
  • Constrained output with grammars
    - • Instruct mode
    
  • Obtaining and using the Facebook LLaMA 2 model
  • Seminal papers and background on the models
  • Perplexity (measuring model quality)
  • @@ -769,34 +768,6 @@ The `grammars/` folder contains a handful of sample grammars. To write your own, For authoring more complex JSON grammars, you can also check out https://grammar.intrinsiclabs.ai/, a browser app that lets you write TypeScript interfaces which it compiles to GBNF grammars that you can save for local use. Note that the app is built and maintained by members of the community, please file any issues or FRs on [its repo](http://github.com/intrinsiclabsai/gbnfgen) and not this one. -### Instruct mode - -1. First, download and place the `ggml` model into the `./models` folder -2. Run the `main` tool like this: - -``` -./examples/alpaca.sh -``` - -Sample run: - -``` -== Running in interactive mode. == - - Press Ctrl+C to interject at any time. - - Press Return to return control to LLaMA. - - If you want to submit another line, end your input in '\'. - - Below is an instruction that describes a task. Write a response that appropriately completes the request. - -> How many letters are there in the English alphabet? -There 26 letters in the English Alphabet -> What is the most common way of transportation in Amsterdam? -The majority (54%) are using public transit. This includes buses, trams and metros with over 100 lines throughout the city which make it very accessible for tourists to navigate around town as well as locals who commute by tram or metro on a daily basis -> List 5 words that start with "ca". -cadaver, cauliflower, cabbage (vegetable), catalpa (tree) and Cailleach. -> -``` - ### Obtaining and using the Facebook LLaMA 2 model - Refer to [Facebook's LLaMA download page](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) if you want to access the model data. diff --git a/examples/alpaca.sh b/examples/alpaca.sh deleted file mode 100755 index 8d2bae691..000000000 --- a/examples/alpaca.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -# -# Temporary script - will be removed in the future -# - -cd `dirname $0` -cd .. - -./main -m ./models/alpaca.13b.ggmlv3.q8_0.bin \ - --color \ - -f ./prompts/alpaca.txt \ - --ctx_size 2048 \ - -n -1 \ - -ins -b 256 \ - --top_k 10000 \ - --temp 0.2 \ - --repeat_penalty 1.1 \ - -t 7 diff --git a/examples/gpt4all.sh b/examples/gpt4all.sh deleted file mode 100755 index 5fd739e55..000000000 --- a/examples/gpt4all.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -# -# Temporary script - will be removed in the future -# - -cd `dirname $0` -cd .. - -./main --color --instruct --threads 4 \ - --model ./models/gpt4all-7B/gpt4all-lora-quantized.bin \ - --file ./prompts/alpaca.txt \ - --batch_size 8 --ctx_size 2048 -n -1 \ - --repeat_last_n 64 --repeat_penalty 1.3 \ - --n_predict 128 --temp 0.1 --top_k 40 --top_p 0.95 diff --git a/examples/llama2-13b.sh b/examples/llama2-13b.sh deleted file mode 100755 index 92b3f6dd8..000000000 --- a/examples/llama2-13b.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -# -# Temporary script - will be removed in the future -# - -cd `dirname $0` -cd .. - -./main -m models/available/Llama2/13B/llama-2-13b.ggmlv3.q4_0.bin \ - --color \ - --ctx_size 2048 \ - -n -1 \ - -ins -b 256 \ - --top_k 10000 \ - --temp 0.2 \ - --repeat_penalty 1.1 \ - -t 8 diff --git a/examples/llama2.sh b/examples/llama2.sh deleted file mode 100755 index 221b37553..000000000 --- a/examples/llama2.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -# -# Temporary script - will be removed in the future -# - -cd `dirname $0` -cd .. 
- -./main -m models/available/Llama2/7B/llama-2-7b.ggmlv3.q4_0.bin \ - --color \ - --ctx_size 2048 \ - -n -1 \ - -ins -b 256 \ - --top_k 10000 \ - --temp 0.2 \ - --repeat_penalty 1.1 \ - -t 8 From fd5ea0f897ecb3659d6c269ef6f3d833e865ead7 Mon Sep 17 00:00:00 2001 From: slaren Date: Mon, 10 Jun 2024 14:18:41 +0200 Subject: [PATCH 15/15] ci : try win-2019 on server windows test (#7854) --- .github/workflows/server.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index 0789efd18..0d16ef5f1 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -16,11 +16,9 @@ on: branches: - master paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*'] - pull_request_target: + pull_request: types: [opened, synchronize, reopened] paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*'] - schedule: - - cron: '2 4 * * *' concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }} @@ -115,7 +113,7 @@ jobs: server-windows: - runs-on: windows-latest + runs-on: windows-2019 steps: - name: Clone