mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 12:21:40 +01:00
Merge branch 'ggerganov:master' into support_glm_edge_model
This commit is contained in:
commit
bc93d2a44e
26
.github/workflows/server.yml
vendored
26
.github/workflows/server.yml
vendored
@ -76,20 +76,26 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip install -r examples/server/tests/requirements.txt
|
pip install -r examples/server/tests/requirements.txt
|
||||||
|
|
||||||
- name: Verify server deps
|
# Setup nodejs (to be used for verifying bundled index.html)
|
||||||
id: verify_server_deps
|
- uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: 22
|
||||||
|
|
||||||
|
- name: Verify bundled index.html
|
||||||
|
id: verify_server_index_html
|
||||||
run: |
|
run: |
|
||||||
git config --global --add safe.directory $(realpath .)
|
git config --global --add safe.directory $(realpath .)
|
||||||
cd examples/server
|
cd examples/server/webui
|
||||||
git ls-files --others --modified
|
|
||||||
git status
|
git status
|
||||||
./deps.sh
|
npm ci
|
||||||
|
npm run build
|
||||||
git status
|
git status
|
||||||
not_ignored_files="$(git ls-files --others --modified)"
|
modified_files="$(git status -s)"
|
||||||
echo "Modified files: ${not_ignored_files}"
|
echo "Modified files: ${modified_files}"
|
||||||
if [ -n "${not_ignored_files}" ]; then
|
if [ -n "${modified_files}" ]; then
|
||||||
echo "Repository is dirty or server deps are not built as expected"
|
echo "Repository is dirty or server/webui is not built as expected"
|
||||||
echo "${not_ignored_files}"
|
echo "Hint: You may need to follow Web UI build guide in server/README.md"
|
||||||
|
echo "${modified_files}"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -104,6 +104,10 @@ examples/server/*.mjs.hpp
|
|||||||
!examples/sycl/*.bat
|
!examples/sycl/*.bat
|
||||||
!examples/sycl/*.sh
|
!examples/sycl/*.sh
|
||||||
|
|
||||||
|
# Server Web UI temporary files
|
||||||
|
node_modules
|
||||||
|
examples/server/webui/dist
|
||||||
|
|
||||||
# Python
|
# Python
|
||||||
|
|
||||||
/.venv
|
/.venv
|
||||||
|
19
Makefile
19
Makefile
@ -1145,8 +1145,15 @@ $(LIB_COMMON_S): $(OBJ_COMMON)
|
|||||||
# Include dependency files
|
# Include dependency files
|
||||||
-include $(DEP_FILES)
|
-include $(DEP_FILES)
|
||||||
|
|
||||||
|
# Clean generated server assets
|
||||||
|
clean-server-assets:
|
||||||
|
find examples/server -type f -name "*.js.hpp" -delete
|
||||||
|
find examples/server -type f -name "*.mjs.hpp" -delete
|
||||||
|
find examples/server -type f -name "*.css.hpp" -delete
|
||||||
|
find examples/server -type f -name "*.html.hpp" -delete
|
||||||
|
|
||||||
# Clean rule
|
# Clean rule
|
||||||
clean:
|
clean: clean-server-assets
|
||||||
rm -vrf $(BUILD_TARGETS) $(TEST_TARGETS)
|
rm -vrf $(BUILD_TARGETS) $(TEST_TARGETS)
|
||||||
rm -rvf *.a *.dll *.so *.dot
|
rm -rvf *.a *.dll *.so *.dot
|
||||||
find ggml src common tests examples pocs -type f -name "*.o" -delete
|
find ggml src common tests examples pocs -type f -name "*.o" -delete
|
||||||
@ -1354,20 +1361,14 @@ llama-server: \
|
|||||||
examples/server/utils.hpp \
|
examples/server/utils.hpp \
|
||||||
examples/server/httplib.h \
|
examples/server/httplib.h \
|
||||||
examples/server/index.html.hpp \
|
examples/server/index.html.hpp \
|
||||||
examples/server/completion.js.hpp \
|
|
||||||
examples/server/loading.html.hpp \
|
examples/server/loading.html.hpp \
|
||||||
examples/server/deps_daisyui.min.css.hpp \
|
|
||||||
examples/server/deps_markdown-it.js.hpp \
|
|
||||||
examples/server/deps_tailwindcss.js.hpp \
|
|
||||||
examples/server/deps_vue.esm-browser.js.hpp \
|
|
||||||
common/json.hpp \
|
common/json.hpp \
|
||||||
common/stb_image.h \
|
|
||||||
$(OBJ_ALL)
|
$(OBJ_ALL)
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
||||||
|
|
||||||
# Portable equivalent of `cd examples/server/public && xxd -i $(notdir $<) ../$(notdir $<).hpp`:
|
# Portable equivalent of `cd examples/server/public && xxd -i $(notdir $<) ../$(notdir $<).hpp`:
|
||||||
examples/server/%.hpp: examples/server/public/% Makefile
|
examples/server/%.hpp: examples/server/public/% FORCE Makefile
|
||||||
@( export NAME=$(subst .,_,$(subst -,_,$(notdir $<))) && \
|
@( export NAME=$(subst .,_,$(subst -,_,$(notdir $<))) && \
|
||||||
echo "unsigned char $${NAME}[] = {" && \
|
echo "unsigned char $${NAME}[] = {" && \
|
||||||
cat $< | od -v -t x1 -An | sed -E 's/([0-9a-fA-F]+)/0x\1, /g' && \
|
cat $< | od -v -t x1 -An | sed -E 's/([0-9a-fA-F]+)/0x\1, /g' && \
|
||||||
@ -1542,7 +1543,7 @@ llama-q8dot: pocs/vdot/q8dot.cpp ggml/src/ggml.o \
|
|||||||
# Deprecated binaries that we want to keep around long enough for people to migrate to the new filenames, then these can be removed.
|
# Deprecated binaries that we want to keep around long enough for people to migrate to the new filenames, then these can be removed.
|
||||||
#
|
#
|
||||||
# Mark legacy binary targets as .PHONY so that they are always checked.
|
# Mark legacy binary targets as .PHONY so that they are always checked.
|
||||||
.PHONY: main quantize perplexity embedding server
|
.PHONY: FORCE main quantize perplexity embedding server
|
||||||
|
|
||||||
# Define the object file target
|
# Define the object file target
|
||||||
examples/deprecation-warning/deprecation-warning.o: examples/deprecation-warning/deprecation-warning.cpp
|
examples/deprecation-warning/deprecation-warning.o: examples/deprecation-warning/deprecation-warning.cpp
|
||||||
|
@ -39,6 +39,11 @@ cmake --build build --config Release
|
|||||||
```
|
```
|
||||||
|
|
||||||
For more details and a list of supported generators, see the [CMake documentation](https://cmake.org/cmake/help/latest/manual/cmake-generators.7.html).
|
For more details and a list of supported generators, see the [CMake documentation](https://cmake.org/cmake/help/latest/manual/cmake-generators.7.html).
|
||||||
|
- For static builds, add `-DBUILD_SHARED_LIBS=OFF`:
|
||||||
|
```
|
||||||
|
cmake -B build -DBUILD_SHARED_LIBS=OFF
|
||||||
|
cmake --build build --config Release
|
||||||
|
```
|
||||||
|
|
||||||
- Building for Windows (x86, x64 and arm64) with MSVC or clang as compilers:
|
- Building for Windows (x86, x64 and arm64) with MSVC or clang as compilers:
|
||||||
- Install Visual Studio 2022, e.g. via the [Community Edition](https://visualstudio.microsoft.com/de/vs/community/). In the installer, select at least the following options (this also automatically installs the required additional tools like CMake,...):
|
- Install Visual Studio 2022, e.g. via the [Community Edition](https://visualstudio.microsoft.com/de/vs/community/). In the installer, select at least the following options (this also automatically installs the required additional tools like CMake,...):
|
||||||
|
@ -14,7 +14,7 @@ In this section, we cover the most commonly used options for running the `infill
|
|||||||
- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
|
- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
|
||||||
- `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
|
- `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
|
||||||
- `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text.
|
- `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text.
|
||||||
- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
|
- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 4096, but if a LLaMA model was built with a longer context, increasing this value will provide better results for longer input/inference.
|
||||||
- `--spm-infill`: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
|
- `--spm-infill`: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
|
||||||
|
|
||||||
## Input Prompts
|
## Input Prompts
|
||||||
|
@ -12,6 +12,10 @@
|
|||||||
#include "ggml-cuda.h"
|
#include "ggml-cuda.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef GGML_USE_SYCL
|
||||||
|
#include "ggml-sycl.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
#include "ggml-metal.h"
|
#include "ggml-metal.h"
|
||||||
#endif
|
#endif
|
||||||
@ -1215,6 +1219,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||||||
LOG_INF("%s: CLIP using Vulkan backend\n", __func__);
|
LOG_INF("%s: CLIP using Vulkan backend\n", __func__);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef GGML_USE_SYCL
|
||||||
|
new_clip->backend = ggml_backend_sycl_init(0);
|
||||||
|
LOG_INF("%s: CLIP using SYCL backend\n", __func__);
|
||||||
|
#endif
|
||||||
|
|
||||||
if (!new_clip->backend) {
|
if (!new_clip->backend) {
|
||||||
new_clip->backend = ggml_backend_cpu_init();
|
new_clip->backend = ggml_backend_cpu_init();
|
||||||
LOG_INF("%s: CLIP using CPU backend\n", __func__);
|
LOG_INF("%s: CLIP using CPU backend\n", __func__);
|
||||||
|
@ -66,7 +66,7 @@ In this section, we cover the most commonly used options for running the `llama-
|
|||||||
- `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (e.g [https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true](https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true)).
|
- `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (e.g [https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true](https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true)).
|
||||||
- `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
|
- `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
|
||||||
- `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text.
|
- `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text.
|
||||||
- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
|
- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 4096, but if a LLaMA model was built with a longer context, increasing this value will provide better results for longer input/inference.
|
||||||
- `-mli, --multiline-input`: Allows you to write or paste multiple lines without ending each in '\'
|
- `-mli, --multiline-input`: Allows you to write or paste multiple lines without ending each in '\'
|
||||||
- `-t N, --threads N`: Set the number of threads to use during generation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has.
|
- `-t N, --threads N`: Set the number of threads to use during generation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has.
|
||||||
- `-ngl N, --n-gpu-layers N`: When compiled with GPU support, this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
|
- `-ngl N, --n-gpu-layers N`: When compiled with GPU support, this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
|
||||||
@ -131,7 +131,7 @@ During text generation, LLaMA models have a limited context size, which means th
|
|||||||
|
|
||||||
### Context Size
|
### Context Size
|
||||||
|
|
||||||
- `-c N, --ctx-size N`: Set the size of the prompt context (default: 0, 0 = loaded from model). The LLaMA models were built with a context of 2048-8192, which will yield the best results on longer input/inference.
|
- `-c N, --ctx-size N`: Set the size of the prompt context (default: 4096, 0 = loaded from model). If a LLaMA model was built with a longer context, increasing this value will yield the best results on longer input/inference.
|
||||||
|
|
||||||
### Extended Context Size
|
### Extended Context Size
|
||||||
|
|
||||||
@ -348,6 +348,7 @@ These options provide extra functionality and customization when running the LLa
|
|||||||
|
|
||||||
- `-h, --help`: Display a help message showing all available options and their default values. This is particularly useful for checking the latest options and default values, as they can change frequently, and the information in this document may become outdated.
|
- `-h, --help`: Display a help message showing all available options and their default values. This is particularly useful for checking the latest options and default values, as they can change frequently, and the information in this document may become outdated.
|
||||||
- `--verbose-prompt`: Print the prompt before generating text.
|
- `--verbose-prompt`: Print the prompt before generating text.
|
||||||
|
- `--no-display-prompt`: Don't print prompt at generation.
|
||||||
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used.
|
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used.
|
||||||
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance.
|
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance.
|
||||||
- `-hfr URL --hf-repo URL`: The url to the Hugging Face model repository. Used in conjunction with `--hf-file` or `-hff`. The model is downloaded and stored in the file provided by `-m` or `--model`. If `-m` is not provided, the model is auto-stored in the path specified by the `LLAMA_CACHE` environment variable or in an OS-specific local cache.
|
- `-hfr URL --hf-repo URL`: The url to the Hugging Face model repository. Used in conjunction with `--hf-file` or `-hff`. The model is downloaded and stored in the file provided by `-m` or `--model`. If `-m` is not provided, the model is auto-stored in the path specified by the `LLAMA_CACHE` environment variable or in an OS-specific local cache.
|
||||||
|
@ -16,12 +16,7 @@ set(TARGET_SRCS
|
|||||||
)
|
)
|
||||||
set(PUBLIC_ASSETS
|
set(PUBLIC_ASSETS
|
||||||
index.html
|
index.html
|
||||||
completion.js
|
|
||||||
loading.html
|
loading.html
|
||||||
deps_daisyui.min.css
|
|
||||||
deps_markdown-it.js
|
|
||||||
deps_tailwindcss.js
|
|
||||||
deps_vue.esm-browser.js
|
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach(asset ${PUBLIC_ASSETS})
|
foreach(asset ${PUBLIC_ASSETS})
|
||||||
@ -33,11 +28,20 @@ foreach(asset ${PUBLIC_ASSETS})
|
|||||||
OUTPUT "${output}"
|
OUTPUT "${output}"
|
||||||
COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake"
|
COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake"
|
||||||
)
|
)
|
||||||
|
set_source_files_properties(${output} PROPERTIES GENERATED TRUE)
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
add_executable(${TARGET} ${TARGET_SRCS})
|
add_executable(${TARGET} ${TARGET_SRCS})
|
||||||
install(TARGETS ${TARGET} RUNTIME)
|
install(TARGETS ${TARGET} RUNTIME)
|
||||||
|
|
||||||
|
# clean up generated files in pre-build step
|
||||||
|
foreach(asset ${PUBLIC_ASSETS})
|
||||||
|
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
|
||||||
|
add_custom_command(TARGET ${TARGET} PRE_BUILD
|
||||||
|
COMMAND "${CMAKE_COMMAND}" -E remove -f "${output}"
|
||||||
|
)
|
||||||
|
endforeach()
|
||||||
|
|
||||||
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
|
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
|
||||||
if (LLAMA_SERVER_SSL)
|
if (LLAMA_SERVER_SSL)
|
||||||
|
@ -217,6 +217,37 @@ services:
|
|||||||
cmake --build build --config Release -t llama-server
|
cmake --build build --config Release -t llama-server
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Web UI
|
||||||
|
|
||||||
|
The project includes a web-based user interface that enables interaction with the model through the `/chat/completions` endpoint.
|
||||||
|
|
||||||
|
The web UI is developed using:
|
||||||
|
- `vue` framework for frontend development
|
||||||
|
- `tailwindcss` and `daisyui` for styling
|
||||||
|
- `vite` for build tooling
|
||||||
|
|
||||||
|
A pre-built version is available as a single HTML file under `/public` directory.
|
||||||
|
|
||||||
|
To build or to run the dev server (with hot reload):
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# make sure you have nodejs installed
|
||||||
|
cd examples/server/webui
|
||||||
|
npm i
|
||||||
|
|
||||||
|
# to run the dev server
|
||||||
|
npm run dev
|
||||||
|
|
||||||
|
# to build the public/index.html
|
||||||
|
npm run build
|
||||||
|
```
|
||||||
|
|
||||||
|
NOTE: if you are using the vite dev server, you can change the API base URL to llama.cpp. To do that, run this code snippet in browser's console:
|
||||||
|
|
||||||
|
```js
|
||||||
|
localStorage.setItem('base', 'http://localhost:8080')
|
||||||
|
```
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
To get started right away, run the following command, making sure to use the correct path for the model you have:
|
To get started right away, run the following command, making sure to use the correct path for the model you have:
|
||||||
|
@ -1,25 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# Download and update deps for binary
|
|
||||||
|
|
||||||
# get the directory of this script file
|
|
||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
|
||||||
PUBLIC=$DIR/public
|
|
||||||
|
|
||||||
echo "download js bundle files"
|
|
||||||
|
|
||||||
# Note for contributors: Always pin to a specific version "maj.min.patch" to avoid breaking the CI
|
|
||||||
|
|
||||||
curl -L https://cdn.tailwindcss.com/3.4.14 > $PUBLIC/deps_tailwindcss.js
|
|
||||||
echo >> $PUBLIC/deps_tailwindcss.js # add newline
|
|
||||||
|
|
||||||
curl -L https://cdnjs.cloudflare.com/ajax/libs/daisyui/4.12.14/styled.min.css > $PUBLIC/deps_daisyui.min.css
|
|
||||||
curl -L https://cdnjs.cloudflare.com/ajax/libs/daisyui/4.12.14/themes.min.css >> $PUBLIC/deps_daisyui.min.css
|
|
||||||
echo >> $PUBLIC/deps_daisyui.min.css # add newline
|
|
||||||
|
|
||||||
curl -L https://unpkg.com/vue@3.5.12/dist/vue.esm-browser.js > $PUBLIC/deps_vue.esm-browser.js
|
|
||||||
echo >> $PUBLIC/deps_vue.esm-browser.js # add newline
|
|
||||||
|
|
||||||
curl -L https://cdnjs.cloudflare.com/ajax/libs/markdown-it/13.0.2/markdown-it.js > $PUBLIC/deps_markdown-it.js
|
|
||||||
echo >> $PUBLIC/deps_markdown-it.js # add newline
|
|
||||||
|
|
||||||
ls -lah $PUBLIC
|
|
13
examples/server/public/deps_daisyui.min.css
vendored
13
examples/server/public/deps_daisyui.min.css
vendored
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@ -16,12 +16,7 @@
|
|||||||
|
|
||||||
// auto generated files (update with ./deps.sh)
|
// auto generated files (update with ./deps.sh)
|
||||||
#include "index.html.hpp"
|
#include "index.html.hpp"
|
||||||
#include "completion.js.hpp"
|
|
||||||
#include "loading.html.hpp"
|
#include "loading.html.hpp"
|
||||||
#include "deps_daisyui.min.css.hpp"
|
|
||||||
#include "deps_markdown-it.js.hpp"
|
|
||||||
#include "deps_tailwindcss.js.hpp"
|
|
||||||
#include "deps_vue.esm-browser.js.hpp"
|
|
||||||
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
@ -103,12 +98,6 @@ struct server_task_result {
|
|||||||
bool error;
|
bool error;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_static_file {
|
|
||||||
const unsigned char * data;
|
|
||||||
unsigned int size;
|
|
||||||
const char * mime_type;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct slot_params {
|
struct slot_params {
|
||||||
bool stream = true;
|
bool stream = true;
|
||||||
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
|
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
|
||||||
@ -696,8 +685,9 @@ struct server_context {
|
|||||||
|
|
||||||
params_dft.devices = params_base.speculative.devices;
|
params_dft.devices = params_base.speculative.devices;
|
||||||
params_dft.model = params_base.speculative.model;
|
params_dft.model = params_base.speculative.model;
|
||||||
params_dft.n_ctx = params_base.speculative.n_ctx;
|
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
|
||||||
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
|
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
|
||||||
|
params_dft.n_parallel = 1;
|
||||||
|
|
||||||
common_init_result llama_init_dft = common_init_from_params(params_dft);
|
common_init_result llama_init_dft = common_init_from_params(params_dft);
|
||||||
|
|
||||||
@ -717,8 +707,14 @@ struct server_context {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
cparams_dft = common_context_params_to_llama(params_base);
|
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
|
||||||
cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);
|
|
||||||
|
cparams_dft = common_context_params_to_llama(params_dft);
|
||||||
|
cparams_dft.n_batch = n_ctx_dft;
|
||||||
|
|
||||||
|
// force F16 KV cache for the draft model for extra performance
|
||||||
|
cparams_dft.type_k = GGML_TYPE_F16;
|
||||||
|
cparams_dft.type_v = GGML_TYPE_F16;
|
||||||
|
|
||||||
// the context is not needed - we will create one for each slot
|
// the context is not needed - we will create one for each slot
|
||||||
llama_free(llama_init_dft.context);
|
llama_free(llama_init_dft.context);
|
||||||
@ -2322,6 +2318,10 @@ struct server_context {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (slot.state != SLOT_STATE_GENERATING) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
llama_token id = slot.sampled;
|
llama_token id = slot.sampled;
|
||||||
|
|
||||||
struct common_speculative_params params_spec;
|
struct common_speculative_params params_spec;
|
||||||
@ -2446,16 +2446,6 @@ int main(int argc, char ** argv) {
|
|||||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||||
LOG_INF("\n");
|
LOG_INF("\n");
|
||||||
|
|
||||||
// static files
|
|
||||||
std::map<std::string, server_static_file> static_files = {
|
|
||||||
{ "/", { index_html, index_html_len, "text/html; charset=utf-8" }},
|
|
||||||
{ "/completion.js", { completion_js, completion_js_len, "text/javascript; charset=utf-8" }},
|
|
||||||
{ "/deps_daisyui.min.css", { deps_daisyui_min_css, deps_daisyui_min_css_len, "text/css; charset=utf-8" }},
|
|
||||||
{ "/deps_markdown-it.js", { deps_markdown_it_js, deps_markdown_it_js_len, "text/javascript; charset=utf-8" }},
|
|
||||||
{ "/deps_tailwindcss.js", { deps_tailwindcss_js, deps_tailwindcss_js_len, "text/javascript; charset=utf-8" }},
|
|
||||||
{ "/deps_vue.esm-browser.js", { deps_vue_esm_browser_js, deps_vue_esm_browser_js_len, "text/javascript; charset=utf-8" }},
|
|
||||||
};
|
|
||||||
|
|
||||||
std::unique_ptr<httplib::Server> svr;
|
std::unique_ptr<httplib::Server> svr;
|
||||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||||
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
||||||
@ -2536,7 +2526,7 @@ int main(int argc, char ** argv) {
|
|||||||
// Middlewares
|
// Middlewares
|
||||||
//
|
//
|
||||||
|
|
||||||
auto middleware_validate_api_key = [¶ms, &res_error, &static_files](const httplib::Request & req, httplib::Response & res) {
|
auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||||
static const std::unordered_set<std::string> public_endpoints = {
|
static const std::unordered_set<std::string> public_endpoints = {
|
||||||
"/health",
|
"/health",
|
||||||
"/models",
|
"/models",
|
||||||
@ -2549,7 +2539,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If path is public or is static file, skip validation
|
// If path is public or is static file, skip validation
|
||||||
if (public_endpoints.find(req.path) != public_endpoints.end() || static_files.find(req.path) != static_files.end()) {
|
if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3306,15 +3296,12 @@ int main(int argc, char ** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// using embedded static files
|
// using embedded static index.html
|
||||||
for (const auto & it : static_files) {
|
svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
|
||||||
const server_static_file & static_file = it.second;
|
res.set_content(reinterpret_cast<const char*>(index_html), index_html_len, "text/html; charset=utf-8");
|
||||||
svr->Get(it.first.c_str(), [&static_file](const httplib::Request &, httplib::Response & res) {
|
|
||||||
res.set_content(reinterpret_cast<const char*>(static_file.data), static_file.size, static_file.mime_type);
|
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// register API routes
|
// register API routes
|
||||||
svr->Get ("/health", handle_health); // public endpoint (no API key check)
|
svr->Get ("/health", handle_health); // public endpoint (no API key check)
|
||||||
|
268
examples/server/webui/index.html
Normal file
268
examples/server/webui/index.html
Normal file
@ -0,0 +1,268 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1" />
|
||||||
|
<meta name="color-scheme" content="light dark">
|
||||||
|
<title>🦙 llama.cpp - chat</title>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<div id="app" class="opacity-0"> <!-- opacity-0 will be removed on app mounted -->
|
||||||
|
<div class="flex flex-row drawer lg:drawer-open">
|
||||||
|
<input id="toggle-drawer" type="checkbox" class="drawer-toggle" checked />
|
||||||
|
|
||||||
|
<!-- sidebar -->
|
||||||
|
<div class="drawer-side h-screen lg:h-screen z-50 lg:max-w-64">
|
||||||
|
<label for="toggle-drawer" aria-label="close sidebar" class="drawer-overlay"></label>
|
||||||
|
<div class="flex flex-col bg-base-200 min-h-full max-w-[calc(100vw-2em)] py-4 px-4">
|
||||||
|
<div class="flex flex-row items-center justify-between mb-4 mt-4">
|
||||||
|
<h2 class="font-bold ml-4">Conversations</h2>
|
||||||
|
|
||||||
|
<!-- close sidebar button -->
|
||||||
|
<label for="toggle-drawer" class="btn btn-ghost lg:hidden">
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-arrow-bar-left" viewBox="0 0 16 16">
|
||||||
|
<path fill-rule="evenodd" d="M12.5 15a.5.5 0 0 1-.5-.5v-13a.5.5 0 0 1 1 0v13a.5.5 0 0 1-.5.5M10 8a.5.5 0 0 1-.5.5H3.707l2.147 2.146a.5.5 0 0 1-.708.708l-3-3a.5.5 0 0 1 0-.708l3-3a.5.5 0 1 1 .708.708L3.707 7.5H9.5a.5.5 0 0 1 .5.5"/>
|
||||||
|
</svg>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- list of conversations -->
|
||||||
|
<div :class="{
|
||||||
|
'btn btn-ghost justify-start': true,
|
||||||
|
'btn-active': messages.length === 0,
|
||||||
|
}" @click="newConversation">
|
||||||
|
+ New conversation
|
||||||
|
</div>
|
||||||
|
<div v-for="conv in conversations" :class="{
|
||||||
|
'btn btn-ghost justify-start font-normal': true,
|
||||||
|
'btn-active': conv.id === viewingConvId,
|
||||||
|
}" @click="setViewingConv(conv.id)">
|
||||||
|
<span class="truncate">{{ conv.messages[0].content }}</span>
|
||||||
|
</div>
|
||||||
|
<div class="text-center text-xs opacity-40 mt-auto mx-4">
|
||||||
|
Conversations are saved to browser's localStorage
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- main view -->
|
||||||
|
<div class="chat-screen drawer-content grow flex flex-col h-screen w-screen mx-auto px-4">
|
||||||
|
<!-- header -->
|
||||||
|
<div class="flex flex-row items-center mt-6 mb-6">
|
||||||
|
<!-- open sidebar button -->
|
||||||
|
<label for="toggle-drawer" class="btn btn-ghost lg:hidden">
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-list" viewBox="0 0 16 16">
|
||||||
|
<path fill-rule="evenodd" d="M2.5 12a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 0 1H3a.5.5 0 0 1-.5-.5m0-4a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 0 1H3a.5.5 0 0 1-.5-.5m0-4a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 0 1H3a.5.5 0 0 1-.5-.5"/>
|
||||||
|
</svg>
|
||||||
|
</label>
|
||||||
|
|
||||||
|
<div class="grow text-2xl font-bold ml-2">llama.cpp</div>
|
||||||
|
|
||||||
|
<!-- action buttons (top right) -->
|
||||||
|
<div class="flex items-center">
|
||||||
|
<div v-if="messages.length > 0" class="dropdown dropdown-end">
|
||||||
|
<!-- "more" button -->
|
||||||
|
<button tabindex="0" role="button" class="btn m-1" :disabled="isGenerating">
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-three-dots-vertical" viewBox="0 0 16 16">
|
||||||
|
<path d="M9.5 13a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0"/>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
<!-- "more" dropdown menu -->
|
||||||
|
<ul tabindex="0" class="dropdown-content menu bg-base-100 rounded-box z-[1] w-52 p-2 shadow">
|
||||||
|
<li @click="downloadConv(viewingConvId)"><a>Download</a></li>
|
||||||
|
<li class="text-error" @click="deleteConv(viewingConvId)"><a>Delete</a></li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
<button class="btn" @click="showConfigDialog = true" :disabled="isGenerating">
|
||||||
|
<!-- settings button -->
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-gear" viewBox="0 0 16 16">
|
||||||
|
<path d="M8 4.754a3.246 3.246 0 1 0 0 6.492 3.246 3.246 0 0 0 0-6.492M5.754 8a2.246 2.246 0 1 1 4.492 0 2.246 2.246 0 0 1-4.492 0"/>
|
||||||
|
<path d="M9.796 1.343c-.527-1.79-3.065-1.79-3.592 0l-.094.319a.873.873 0 0 1-1.255.52l-.292-.16c-1.64-.892-3.433.902-2.54 2.541l.159.292a.873.873 0 0 1-.52 1.255l-.319.094c-1.79.527-1.79 3.065 0 3.592l.319.094a.873.873 0 0 1 .52 1.255l-.16.292c-.892 1.64.901 3.434 2.541 2.54l.292-.159a.873.873 0 0 1 1.255.52l.094.319c.527 1.79 3.065 1.79 3.592 0l.094-.319a.873.873 0 0 1 1.255-.52l.292.16c1.64.893 3.434-.902 2.54-2.541l-.159-.292a.873.873 0 0 1 .52-1.255l.319-.094c1.79-.527 1.79-3.065 0-3.592l-.319-.094a.873.873 0 0 1-.52-1.255l.16-.292c.893-1.64-.902-3.433-2.541-2.54l-.292.159a.873.873 0 0 1-1.255-.52zm-2.633.283c.246-.835 1.428-.835 1.674 0l.094.319a1.873 1.873 0 0 0 2.693 1.115l.291-.16c.764-.415 1.6.42 1.184 1.185l-.159.292a1.873 1.873 0 0 0 1.116 2.692l.318.094c.835.246.835 1.428 0 1.674l-.319.094a1.873 1.873 0 0 0-1.115 2.693l.16.291c.415.764-.42 1.6-1.185 1.184l-.291-.159a1.873 1.873 0 0 0-2.693 1.116l-.094.318c-.246.835-1.428.835-1.674 0l-.094-.319a1.873 1.873 0 0 0-2.692-1.115l-.292.16c-.764.415-1.6-.42-1.184-1.185l.159-.291A1.873 1.873 0 0 0 1.945 8.93l-.319-.094c-.835-.246-.835-1.428 0-1.674l.319-.094A1.873 1.873 0 0 0 3.06 4.377l-.16-.292c-.415-.764.42-1.6 1.185-1.184l.292.159a1.873 1.873 0 0 0 2.692-1.115z"/>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<!-- theme controller is copied from https://daisyui.com/components/theme-controller/ -->
|
||||||
|
<div class="dropdown dropdown-end dropdown-bottom">
|
||||||
|
<div tabindex="0" role="button" class="btn m-1">
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-palette2" viewBox="0 0 16 16">
|
||||||
|
<path d="M0 .5A.5.5 0 0 1 .5 0h5a.5.5 0 0 1 .5.5v5.277l4.147-4.131a.5.5 0 0 1 .707 0l3.535 3.536a.5.5 0 0 1 0 .708L10.261 10H15.5a.5.5 0 0 1 .5.5v5a.5.5 0 0 1-.5.5H3a3 3 0 0 1-2.121-.879A3 3 0 0 1 0 13.044m6-.21 7.328-7.3-2.829-2.828L6 7.188zM4.5 13a1.5 1.5 0 1 0-3 0 1.5 1.5 0 0 0 3 0M15 15v-4H9.258l-4.015 4zM0 .5v12.495zm0 12.495V13z"/>
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
<ul tabindex="0" class="dropdown-content bg-base-300 rounded-box z-[1] w-52 p-2 shadow-2xl h-80 overflow-y-auto">
|
||||||
|
<li>
|
||||||
|
<button
|
||||||
|
class="btn btn-sm btn-block btn-ghost justify-start"
|
||||||
|
:class="{ 'btn-active': selectedTheme === 'auto' }"
|
||||||
|
@click="setSelectedTheme('auto')">
|
||||||
|
auto
|
||||||
|
</button>
|
||||||
|
</li>
|
||||||
|
<li v-for="theme in themes">
|
||||||
|
<input
|
||||||
|
type="radio"
|
||||||
|
name="theme-dropdown"
|
||||||
|
class="theme-controller btn btn-sm btn-block btn-ghost justify-start"
|
||||||
|
:aria-label="theme"
|
||||||
|
:value="theme"
|
||||||
|
:checked="selectedTheme === theme"
|
||||||
|
@click="setSelectedTheme(theme)" />
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- chat messages -->
|
||||||
|
<div id="messages-list" class="flex flex-col grow overflow-y-auto">
|
||||||
|
<div class="mt-auto flex justify-center">
|
||||||
|
<!-- placeholder to shift the message to the bottom -->
|
||||||
|
{{ messages.length === 0 ? 'Send a message to start' : '' }}
|
||||||
|
</div>
|
||||||
|
<div v-for="msg in messages" class="group">
|
||||||
|
<div :class="{
|
||||||
|
'chat': true,
|
||||||
|
'chat-start': msg.role !== 'user',
|
||||||
|
'chat-end': msg.role === 'user',
|
||||||
|
}">
|
||||||
|
<div :class="{
|
||||||
|
'chat-bubble markdown': true,
|
||||||
|
'chat-bubble-base-300': msg.role !== 'user',
|
||||||
|
}">
|
||||||
|
<!-- textarea for editing message -->
|
||||||
|
<template v-if="editingMsg && editingMsg.id === msg.id">
|
||||||
|
<textarea
|
||||||
|
class="textarea textarea-bordered bg-base-100 text-base-content w-[calc(90vw-8em)] lg:w-96"
|
||||||
|
v-model="msg.content"></textarea>
|
||||||
|
<br/>
|
||||||
|
<button class="btn btn-ghost mt-2 mr-2" @click="editingMsg = null">Cancel</button>
|
||||||
|
<button class="btn mt-2" @click="editUserMsgAndRegenerate(msg)">Submit</button>
|
||||||
|
</template>
|
||||||
|
<!-- render message as markdown -->
|
||||||
|
<vue-markdown v-else :source="msg.content" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- actions for each message -->
|
||||||
|
<div :class="{'text-right': msg.role === 'user'}" class="mx-4 mt-2 mb-2">
|
||||||
|
<!-- user message -->
|
||||||
|
<button v-if="msg.role === 'user'" class="badge btn-mini show-on-hover" @click="editingMsg = msg" :disabled="isGenerating">
|
||||||
|
✍️ Edit
|
||||||
|
</button>
|
||||||
|
<!-- assistant message -->
|
||||||
|
<button v-if="msg.role === 'assistant'" class="badge btn-mini show-on-hover mr-2" @click="regenerateMsg(msg)" :disabled="isGenerating">
|
||||||
|
🔄 Regenerate
|
||||||
|
</button>
|
||||||
|
<button v-if="msg.role === 'assistant'" class="badge btn-mini show-on-hover mr-2" @click="copyMsg(msg)" :disabled="isGenerating">
|
||||||
|
📋 Copy
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- pending (ongoing) assistant message -->
|
||||||
|
<div id="pending-msg" class="chat chat-start">
|
||||||
|
<div v-if="pendingMsg" class="chat-bubble markdown chat-bubble-base-300">
|
||||||
|
<span v-if="!pendingMsg.content" class="loading loading-dots loading-md"></span>
|
||||||
|
<vue-markdown v-else :source="pendingMsg.content" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- chat input -->
|
||||||
|
<div class="flex flex-row items-center mt-8 mb-6">
|
||||||
|
<textarea
|
||||||
|
class="textarea textarea-bordered w-full"
|
||||||
|
placeholder="Type a message (Shift+Enter to add a new line)"
|
||||||
|
v-model="inputMsg"
|
||||||
|
@keydown.enter.exact.prevent="sendMessage"
|
||||||
|
@keydown.enter.shift.exact.prevent="inputMsg += '\n'"
|
||||||
|
:disabled="isGenerating"
|
||||||
|
id="msg-input"
|
||||||
|
></textarea>
|
||||||
|
<button v-if="!isGenerating" class="btn btn-primary ml-2" @click="sendMessage" :disabled="inputMsg.length === 0">Send</button>
|
||||||
|
<button v-else class="btn btn-neutral ml-2" @click="stopGeneration">Stop</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<!-- modal for editing config -->
|
||||||
|
<dialog class="modal" :class="{'modal-open': showConfigDialog}">
|
||||||
|
<div class="modal-box">
|
||||||
|
<h3 class="text-lg font-bold mb-6">Settings</h3>
|
||||||
|
<div class="h-[calc(90vh-12rem)] overflow-y-auto">
|
||||||
|
<p class="opacity-40 mb-6">Settings below are saved in browser's localStorage</p>
|
||||||
|
<settings-modal-short-input :config-key="'apiKey'" :config-default="configDefault" :config-info="configInfo" v-model="config.apiKey"></settings-modal-short-input>
|
||||||
|
<label class="form-control mb-2">
|
||||||
|
<div class="label">System Message</div>
|
||||||
|
<textarea class="textarea textarea-bordered h-24" :placeholder="'Default: ' + configDefault.systemMessage" v-model="config.systemMessage"></textarea>
|
||||||
|
</label>
|
||||||
|
<template v-for="configKey in ['temperature', 'top_k', 'top_p', 'min_p', 'max_tokens']">
|
||||||
|
<settings-modal-short-input :config-key="configKey" :config-default="configDefault" :config-info="configInfo" v-model="config[configKey]"></settings-modal-short-input>
|
||||||
|
</template>
|
||||||
|
<!-- TODO: add more sampling-related configs, please regroup them into different "collapse" sections -->
|
||||||
|
<!-- Section: Other sampler settings -->
|
||||||
|
<details class="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
|
||||||
|
<summary class="collapse-title font-bold">Other sampler settings</summary>
|
||||||
|
<div class="collapse-content">
|
||||||
|
<!-- Samplers queue -->
|
||||||
|
<settings-modal-short-input label="Samplers queue" :config-key="'samplers'" :config-default="configDefault" :config-info="configInfo" v-model="config.samplers"></settings-modal-short-input>
|
||||||
|
<!-- Samplers -->
|
||||||
|
<template v-for="configKey in ['dynatemp_range', 'dynatemp_exponent', 'typical_p', 'xtc_probability', 'xtc_threshold']">
|
||||||
|
<settings-modal-short-input :config-key="configKey" :config-default="configDefault" :config-info="configInfo" v-model="config[configKey]"></settings-modal-short-input>
|
||||||
|
</template>
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
<!-- Section: Penalties settings -->
|
||||||
|
<details class="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
|
||||||
|
<summary class="collapse-title font-bold">Penalties settings</summary>
|
||||||
|
<div class="collapse-content">
|
||||||
|
<template v-for="configKey in ['repeat_last_n', 'repeat_penalty', 'presence_penalty', 'frequency_penalty', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_penalty_last_n']">
|
||||||
|
<settings-modal-short-input :config-key="configKey" :config-default="configDefault" :config-info="configInfo" v-model="config[configKey]"></settings-modal-short-input>
|
||||||
|
</template>
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
<!-- Section: Advanced config -->
|
||||||
|
<details class="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
|
||||||
|
<summary class="collapse-title font-bold">Advanced config</summary>
|
||||||
|
<div class="collapse-content">
|
||||||
|
<label class="form-control mb-2">
|
||||||
|
<!-- Custom parameters input -->
|
||||||
|
<div class="label inline">Custom JSON config (For more info, refer to <a class="underline" href="https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md" target="_blank" rel="noopener noreferrer">server documentation</a>)</div>
|
||||||
|
<textarea class="textarea textarea-bordered h-24" placeholder="Example: { "mirostat": 1, "min_p": 0.1 }" v-model="config.custom"></textarea>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- action buttons -->
|
||||||
|
<div class="modal-action">
|
||||||
|
<button class="btn" @click="resetConfigDialog">Reset to default</button>
|
||||||
|
<button class="btn" @click="closeAndDiscardConfigDialog">Close</button>
|
||||||
|
<button class="btn btn-primary" @click="closeAndSaveConfigDialog">Save</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</dialog>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Template to be used by settings modal -->
|
||||||
|
<template id="settings-modal-short-input">
|
||||||
|
<label class="input input-bordered join-item grow flex items-center gap-2 mb-2">
|
||||||
|
<!-- Show help message on hovering on the input label -->
|
||||||
|
<div class="dropdown dropdown-hover">
|
||||||
|
<div tabindex="0" role="button" class="font-bold">{{ label || configKey }}</div>
|
||||||
|
<div class="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
|
||||||
|
{{ configInfo[configKey] || '(no help message available)' }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<!-- Here we forward v-model from parent to child component, see: https://stackoverflow.com/questions/47311936/v-model-and-child-components -->
|
||||||
|
<input type="text" class="grow" :placeholder="'Default: ' + (configDefault[configKey] || 'none')" :value="modelValue" @input="$emit('update:modelValue', $event.target.value)" />
|
||||||
|
</label>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script type="module" src="/src/main.js"></script>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
2783
examples/server/webui/package-lock.json
generated
Normal file
2783
examples/server/webui/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
23
examples/server/webui/package.json
Normal file
23
examples/server/webui/package.json
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
{
|
||||||
|
"name": "webui",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.0.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "vite build",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"vite": "^5.4.10"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"autoprefixer": "^10.4.20",
|
||||||
|
"daisyui": "^4.12.14",
|
||||||
|
"markdown-it": "^14.1.0",
|
||||||
|
"postcss": "^8.4.49",
|
||||||
|
"tailwindcss": "^3.4.15",
|
||||||
|
"vite-plugin-singlefile": "^2.0.3",
|
||||||
|
"vue": "^3.5.13"
|
||||||
|
}
|
||||||
|
}
|
6
examples/server/webui/postcss.config.js
Normal file
6
examples/server/webui/postcss.config.js
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
export default {
|
||||||
|
plugins: {
|
||||||
|
tailwindcss: {},
|
||||||
|
autoprefixer: {},
|
||||||
|
},
|
||||||
|
}
|
456
examples/server/webui/src/main.js
Normal file
456
examples/server/webui/src/main.js
Normal file
@ -0,0 +1,456 @@
|
|||||||
|
import './styles.css';
|
||||||
|
import { createApp, defineComponent, shallowRef, computed, h } from 'vue/dist/vue.esm-bundler.js';
|
||||||
|
import { llama } from './completion.js';
|
||||||
|
import MarkdownIt from 'markdown-it';
|
||||||
|
|
||||||
|
// utility functions
|
||||||
|
const isString = (x) => !!x.toLowerCase;
|
||||||
|
const isNumeric = (n) => !isString(n) && !isNaN(n);
|
||||||
|
const escapeAttr = (str) => str.replace(/>/g, '>').replace(/"/g, '"');
|
||||||
|
const copyStr = (str) => navigator.clipboard.writeText(str);
|
||||||
|
|
||||||
|
// constants
|
||||||
|
const BASE_URL = localStorage.getItem('base') // for debugging
|
||||||
|
|| (new URL('.', document.baseURI).href).toString(); // for production
|
||||||
|
const CONFIG_DEFAULT = {
|
||||||
|
// Note: in order not to introduce breaking changes, please keep the same data type (number, string, etc) if you want to change the default value. Do not use null or undefined for default value.
|
||||||
|
apiKey: '',
|
||||||
|
systemMessage: 'You are a helpful assistant.',
|
||||||
|
// make sure these default values are in sync with `common.h`
|
||||||
|
samplers: 'dkypmxt',
|
||||||
|
temperature: 0.8,
|
||||||
|
dynatemp_range: 0.0,
|
||||||
|
dynatemp_exponent: 1.0,
|
||||||
|
top_k: 40,
|
||||||
|
top_p: 0.95,
|
||||||
|
min_p: 0.05,
|
||||||
|
xtc_probability: 0.0,
|
||||||
|
xtc_threshold: 0.1,
|
||||||
|
typical_p: 1.0,
|
||||||
|
repeat_last_n: 64,
|
||||||
|
repeat_penalty: 1.0,
|
||||||
|
presence_penalty: 0.0,
|
||||||
|
frequency_penalty: 0.0,
|
||||||
|
dry_multiplier: 0.0,
|
||||||
|
dry_base: 1.75,
|
||||||
|
dry_allowed_length: 2,
|
||||||
|
dry_penalty_last_n: -1,
|
||||||
|
max_tokens: -1,
|
||||||
|
custom: '', // custom json-stringified object
|
||||||
|
};
|
||||||
|
const CONFIG_INFO = {
|
||||||
|
apiKey: 'Set the API Key if you are using --api-key option for the server.',
|
||||||
|
systemMessage: 'The starting message that defines how model should behave.',
|
||||||
|
samplers: 'The order at which samplers are applied, in simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature',
|
||||||
|
temperature: 'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.',
|
||||||
|
dynatemp_range: 'Addon for the temperature sampler. The added value to the range of dynamic temperature, which adjusts probabilities by entropy of tokens.',
|
||||||
|
dynatemp_exponent: 'Addon for the temperature sampler. Smoothes out the probability redistribution based on the most probable token.',
|
||||||
|
top_k: 'Keeps only k top tokens.',
|
||||||
|
top_p: 'Limits tokens to those that together have a cumulative probability of at least p',
|
||||||
|
min_p: 'Limits tokens based on the minimum probability for a token to be considered, relative to the probability of the most likely token.',
|
||||||
|
xtc_probability: 'XTC sampler cuts out top tokens; this parameter controls the chance of cutting tokens at all. 0 disables XTC.',
|
||||||
|
xtc_threshold: 'XTC sampler cuts out top tokens; this parameter controls the token probability that is required to cut that token.',
|
||||||
|
typical_p: 'Sorts and limits tokens based on the difference between log-probability and entropy.',
|
||||||
|
repeat_last_n: 'Last n tokens to consider for penalizing repetition',
|
||||||
|
repeat_penalty: 'Controls the repetition of token sequences in the generated text',
|
||||||
|
presence_penalty: 'Limits tokens based on whether they appear in the output or not.',
|
||||||
|
frequency_penalty: 'Limits tokens based on how often they appear in the output.',
|
||||||
|
dry_multiplier: 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling multiplier.',
|
||||||
|
dry_base: 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling base value.',
|
||||||
|
dry_allowed_length: 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the allowed length for DRY sampling.',
|
||||||
|
dry_penalty_last_n: 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets DRY penalty for the last n tokens.',
|
||||||
|
max_tokens: 'The maximum number of token per output.',
|
||||||
|
custom: '', // custom json-stringified object
|
||||||
|
};
|
||||||
|
// config keys having numeric value (i.e. temperature, top_k, top_p, etc)
|
||||||
|
const CONFIG_NUMERIC_KEYS = Object.entries(CONFIG_DEFAULT).filter(e => isNumeric(e[1])).map(e => e[0]);
|
||||||
|
// list of themes supported by daisyui
|
||||||
|
const THEMES = ['light', 'dark', 'cupcake', 'bumblebee', 'emerald', 'corporate', 'synthwave', 'retro', 'cyberpunk', 'valentine', 'halloween', 'garden', 'forest', 'aqua', 'lofi', 'pastel', 'fantasy', 'wireframe', 'black', 'luxury', 'dracula', 'cmyk', 'autumn', 'business', 'acid', 'lemonade', 'night', 'coffee', 'winter', 'dim', 'nord', 'sunset'];
|
||||||
|
|
||||||
|
// markdown support
|
||||||
|
const VueMarkdown = defineComponent(
|
||||||
|
(props) => {
|
||||||
|
const md = shallowRef(new MarkdownIt({ breaks: true }));
|
||||||
|
const origFenchRenderer = md.value.renderer.rules.fence;
|
||||||
|
md.value.renderer.rules.fence = (tokens, idx, ...args) => {
|
||||||
|
const content = tokens[idx].content;
|
||||||
|
const origRendered = origFenchRenderer(tokens, idx, ...args);
|
||||||
|
return `<div class="relative my-4">
|
||||||
|
<div class="text-right sticky top-4 mb-2 mr-2 h-0">
|
||||||
|
<button class="badge btn-mini" onclick="copyStr(${escapeAttr(JSON.stringify(content))})">📋 Copy</button>
|
||||||
|
</div>
|
||||||
|
${origRendered}
|
||||||
|
</div>`;
|
||||||
|
};
|
||||||
|
window.copyStr = copyStr;
|
||||||
|
const content = computed(() => md.value.render(props.source));
|
||||||
|
return () => h("div", { innerHTML: content.value });
|
||||||
|
},
|
||||||
|
{ props: ["source"] }
|
||||||
|
);
|
||||||
|
|
||||||
|
// input field to be used by settings modal
|
||||||
|
const SettingsModalShortInput = defineComponent({
|
||||||
|
template: document.getElementById('settings-modal-short-input').innerHTML,
|
||||||
|
props: {
|
||||||
|
label: { type: String, required: false },
|
||||||
|
configKey: String,
|
||||||
|
configDefault: Object,
|
||||||
|
configInfo: Object,
|
||||||
|
modelValue: [Object, String, Number],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// coversations is stored in localStorage
|
||||||
|
// format: { [convId]: { id: string, lastModified: number, messages: [...] } }
|
||||||
|
// convId is a string prefixed with 'conv-'
|
||||||
|
const StorageUtils = {
|
||||||
|
// manage conversations
|
||||||
|
getAllConversations() {
|
||||||
|
const res = [];
|
||||||
|
for (const key in localStorage) {
|
||||||
|
if (key.startsWith('conv-')) {
|
||||||
|
res.push(JSON.parse(localStorage.getItem(key)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
res.sort((a, b) => b.lastModified - a.lastModified);
|
||||||
|
return res;
|
||||||
|
},
|
||||||
|
// can return null if convId does not exist
|
||||||
|
getOneConversation(convId) {
|
||||||
|
return JSON.parse(localStorage.getItem(convId) || 'null');
|
||||||
|
},
|
||||||
|
// if convId does not exist, create one
|
||||||
|
appendMsg(convId, msg) {
|
||||||
|
if (msg.content === null) return;
|
||||||
|
const conv = StorageUtils.getOneConversation(convId) || {
|
||||||
|
id: convId,
|
||||||
|
lastModified: Date.now(),
|
||||||
|
messages: [],
|
||||||
|
};
|
||||||
|
conv.messages.push(msg);
|
||||||
|
conv.lastModified = Date.now();
|
||||||
|
localStorage.setItem(convId, JSON.stringify(conv));
|
||||||
|
},
|
||||||
|
getNewConvId() {
|
||||||
|
return `conv-${Date.now()}`;
|
||||||
|
},
|
||||||
|
remove(convId) {
|
||||||
|
localStorage.removeItem(convId);
|
||||||
|
},
|
||||||
|
filterAndKeepMsgs(convId, predicate) {
|
||||||
|
const conv = StorageUtils.getOneConversation(convId);
|
||||||
|
if (!conv) return;
|
||||||
|
conv.messages = conv.messages.filter(predicate);
|
||||||
|
conv.lastModified = Date.now();
|
||||||
|
localStorage.setItem(convId, JSON.stringify(conv));
|
||||||
|
},
|
||||||
|
popMsg(convId) {
|
||||||
|
const conv = StorageUtils.getOneConversation(convId);
|
||||||
|
if (!conv) return;
|
||||||
|
const msg = conv.messages.pop();
|
||||||
|
conv.lastModified = Date.now();
|
||||||
|
if (conv.messages.length === 0) {
|
||||||
|
StorageUtils.remove(convId);
|
||||||
|
} else {
|
||||||
|
localStorage.setItem(convId, JSON.stringify(conv));
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
|
},
|
||||||
|
|
||||||
|
// manage config
|
||||||
|
getConfig() {
|
||||||
|
const savedVal = JSON.parse(localStorage.getItem('config') || '{}');
|
||||||
|
// to prevent breaking changes in the future, we always provide default value for missing keys
|
||||||
|
return {
|
||||||
|
...CONFIG_DEFAULT,
|
||||||
|
...savedVal,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
setConfig(config) {
|
||||||
|
localStorage.setItem('config', JSON.stringify(config));
|
||||||
|
},
|
||||||
|
getTheme() {
|
||||||
|
return localStorage.getItem('theme') || 'auto';
|
||||||
|
},
|
||||||
|
setTheme(theme) {
|
||||||
|
if (theme === 'auto') {
|
||||||
|
localStorage.removeItem('theme');
|
||||||
|
} else {
|
||||||
|
localStorage.setItem('theme', theme);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// scroll to bottom of chat messages
|
||||||
|
// if requiresNearBottom is true, only auto-scroll if user is near bottom
|
||||||
|
const chatScrollToBottom = (requiresNearBottom) => {
|
||||||
|
const msgListElem = document.getElementById('messages-list');
|
||||||
|
const spaceToBottom = msgListElem.scrollHeight - msgListElem.scrollTop - msgListElem.clientHeight;
|
||||||
|
if (!requiresNearBottom || (spaceToBottom < 100)) {
|
||||||
|
setTimeout(() => msgListElem.scrollTo({ top: msgListElem.scrollHeight }), 1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const mainApp = createApp({
|
||||||
|
components: {
|
||||||
|
VueMarkdown,
|
||||||
|
SettingsModalShortInput,
|
||||||
|
},
|
||||||
|
data() {
|
||||||
|
return {
|
||||||
|
conversations: StorageUtils.getAllConversations(),
|
||||||
|
messages: [], // { id: number, role: 'user' | 'assistant', content: string }
|
||||||
|
viewingConvId: StorageUtils.getNewConvId(),
|
||||||
|
inputMsg: '',
|
||||||
|
isGenerating: false,
|
||||||
|
pendingMsg: null, // the on-going message from assistant
|
||||||
|
stopGeneration: () => {},
|
||||||
|
selectedTheme: StorageUtils.getTheme(),
|
||||||
|
config: StorageUtils.getConfig(),
|
||||||
|
showConfigDialog: false,
|
||||||
|
editingMsg: null,
|
||||||
|
// const
|
||||||
|
themes: THEMES,
|
||||||
|
configDefault: {...CONFIG_DEFAULT},
|
||||||
|
configInfo: {...CONFIG_INFO},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
computed: {},
|
||||||
|
mounted() {
|
||||||
|
document.getElementById('app').classList.remove('opacity-0'); // show app
|
||||||
|
// scroll to the bottom when the pending message height is updated
|
||||||
|
const pendingMsgElem = document.getElementById('pending-msg');
|
||||||
|
const resizeObserver = new ResizeObserver(() => {
|
||||||
|
if (this.isGenerating) chatScrollToBottom(true);
|
||||||
|
});
|
||||||
|
resizeObserver.observe(pendingMsgElem);
|
||||||
|
},
|
||||||
|
methods: {
|
||||||
|
hideSidebar() {
|
||||||
|
document.getElementById('toggle-drawer').checked = false;
|
||||||
|
},
|
||||||
|
setSelectedTheme(theme) {
|
||||||
|
this.selectedTheme = theme;
|
||||||
|
StorageUtils.setTheme(theme);
|
||||||
|
},
|
||||||
|
newConversation() {
|
||||||
|
if (this.isGenerating) return;
|
||||||
|
this.viewingConvId = StorageUtils.getNewConvId();
|
||||||
|
this.editingMsg = null;
|
||||||
|
this.fetchMessages();
|
||||||
|
chatScrollToBottom();
|
||||||
|
this.hideSidebar();
|
||||||
|
},
|
||||||
|
setViewingConv(convId) {
|
||||||
|
if (this.isGenerating) return;
|
||||||
|
this.viewingConvId = convId;
|
||||||
|
this.editingMsg = null;
|
||||||
|
this.fetchMessages();
|
||||||
|
chatScrollToBottom();
|
||||||
|
this.hideSidebar();
|
||||||
|
},
|
||||||
|
deleteConv(convId) {
|
||||||
|
if (this.isGenerating) return;
|
||||||
|
if (window.confirm('Are you sure to delete this conversation?')) {
|
||||||
|
StorageUtils.remove(convId);
|
||||||
|
if (this.viewingConvId === convId) {
|
||||||
|
this.viewingConvId = StorageUtils.getNewConvId();
|
||||||
|
this.editingMsg = null;
|
||||||
|
}
|
||||||
|
this.fetchConversation();
|
||||||
|
this.fetchMessages();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
downloadConv(convId) {
|
||||||
|
const conversation = StorageUtils.getOneConversation(convId);
|
||||||
|
if (!conversation) {
|
||||||
|
alert('Conversation not found.');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const conversationJson = JSON.stringify(conversation, null, 2);
|
||||||
|
const blob = new Blob([conversationJson], { type: 'application/json' });
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = document.createElement('a');
|
||||||
|
a.href = url;
|
||||||
|
a.download = `conversation_${convId}.json`;
|
||||||
|
document.body.appendChild(a);
|
||||||
|
a.click();
|
||||||
|
document.body.removeChild(a);
|
||||||
|
URL.revokeObjectURL(url);
|
||||||
|
},
|
||||||
|
async sendMessage() {
|
||||||
|
if (!this.inputMsg) return;
|
||||||
|
const currConvId = this.viewingConvId;
|
||||||
|
|
||||||
|
StorageUtils.appendMsg(currConvId, {
|
||||||
|
id: Date.now(),
|
||||||
|
role: 'user',
|
||||||
|
content: this.inputMsg,
|
||||||
|
});
|
||||||
|
this.fetchConversation();
|
||||||
|
this.fetchMessages();
|
||||||
|
this.inputMsg = '';
|
||||||
|
this.editingMsg = null;
|
||||||
|
this.generateMessage(currConvId);
|
||||||
|
chatScrollToBottom();
|
||||||
|
},
|
||||||
|
async generateMessage(currConvId) {
|
||||||
|
if (this.isGenerating) return;
|
||||||
|
this.pendingMsg = { id: Date.now()+1, role: 'assistant', content: null };
|
||||||
|
this.isGenerating = true;
|
||||||
|
this.editingMsg = null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const abortController = new AbortController();
|
||||||
|
this.stopGeneration = () => abortController.abort();
|
||||||
|
const params = {
|
||||||
|
messages: [
|
||||||
|
{ role: 'system', content: this.config.systemMessage },
|
||||||
|
...this.messages,
|
||||||
|
],
|
||||||
|
stream: true,
|
||||||
|
cache_prompt: true,
|
||||||
|
samplers: this.config.samplers,
|
||||||
|
temperature: this.config.temperature,
|
||||||
|
dynatemp_range: this.config.dynatemp_range,
|
||||||
|
dynatemp_exponent: this.config.dynatemp_exponent,
|
||||||
|
top_k: this.config.top_k,
|
||||||
|
top_p: this.config.top_p,
|
||||||
|
min_p: this.config.min_p,
|
||||||
|
typical_p: this.config.typical_p,
|
||||||
|
xtc_probability: this.config.xtc_probability,
|
||||||
|
xtc_threshold: this.config.xtc_threshold,
|
||||||
|
repeat_last_n: this.config.repeat_last_n,
|
||||||
|
repeat_penalty: this.config.repeat_penalty,
|
||||||
|
presence_penalty: this.config.presence_penalty,
|
||||||
|
frequency_penalty: this.config.frequency_penalty,
|
||||||
|
dry_multiplier: this.config.dry_multiplier,
|
||||||
|
dry_base: this.config.dry_base,
|
||||||
|
dry_allowed_length: this.config.dry_allowed_length,
|
||||||
|
dry_penalty_last_n: this.config.dry_penalty_last_n,
|
||||||
|
max_tokens: this.config.max_tokens,
|
||||||
|
...(this.config.custom.length ? JSON.parse(this.config.custom) : {}),
|
||||||
|
...(this.config.apiKey ? { api_key: this.config.apiKey } : {}),
|
||||||
|
};
|
||||||
|
const config = {
|
||||||
|
controller: abortController,
|
||||||
|
api_url: BASE_URL,
|
||||||
|
endpoint: '/chat/completions',
|
||||||
|
};
|
||||||
|
for await (const chunk of llama(prompt, params, config)) {
|
||||||
|
const stop = chunk.data.stop;
|
||||||
|
const addedContent = chunk.data.choices[0].delta.content;
|
||||||
|
const lastContent = this.pendingMsg.content || '';
|
||||||
|
if (addedContent) {
|
||||||
|
this.pendingMsg = {
|
||||||
|
id: this.pendingMsg.id,
|
||||||
|
role: 'assistant',
|
||||||
|
content: lastContent + addedContent,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
StorageUtils.appendMsg(currConvId, this.pendingMsg);
|
||||||
|
this.fetchConversation();
|
||||||
|
this.fetchMessages();
|
||||||
|
setTimeout(() => document.getElementById('msg-input').focus(), 1);
|
||||||
|
} catch (error) {
|
||||||
|
if (error.name === 'AbortError') {
|
||||||
|
// user stopped the generation via stopGeneration() function
|
||||||
|
StorageUtils.appendMsg(currConvId, this.pendingMsg);
|
||||||
|
this.fetchConversation();
|
||||||
|
this.fetchMessages();
|
||||||
|
} else {
|
||||||
|
console.error(error);
|
||||||
|
alert(error);
|
||||||
|
// pop last user message
|
||||||
|
const lastUserMsg = StorageUtils.popMsg(currConvId);
|
||||||
|
this.inputMsg = lastUserMsg ? lastUserMsg.content : '';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.pendingMsg = null;
|
||||||
|
this.isGenerating = false;
|
||||||
|
this.stopGeneration = () => {};
|
||||||
|
this.fetchMessages();
|
||||||
|
chatScrollToBottom();
|
||||||
|
},
|
||||||
|
|
||||||
|
// message actions
|
||||||
|
regenerateMsg(msg) {
|
||||||
|
if (this.isGenerating) return;
|
||||||
|
// TODO: somehow keep old history (like how ChatGPT has different "tree"). This can be done by adding "sub-conversations" with "subconv-" prefix, and new message will have a list of subconvIds
|
||||||
|
const currConvId = this.viewingConvId;
|
||||||
|
StorageUtils.filterAndKeepMsgs(currConvId, (m) => m.id < msg.id);
|
||||||
|
this.fetchConversation();
|
||||||
|
this.fetchMessages();
|
||||||
|
this.generateMessage(currConvId);
|
||||||
|
},
|
||||||
|
copyMsg(msg) {
|
||||||
|
copyStr(msg.content);
|
||||||
|
},
|
||||||
|
editUserMsgAndRegenerate(msg) {
|
||||||
|
if (this.isGenerating) return;
|
||||||
|
const currConvId = this.viewingConvId;
|
||||||
|
const newContent = msg.content;
|
||||||
|
this.editingMsg = null;
|
||||||
|
StorageUtils.filterAndKeepMsgs(currConvId, (m) => m.id < msg.id);
|
||||||
|
StorageUtils.appendMsg(currConvId, {
|
||||||
|
id: Date.now(),
|
||||||
|
role: 'user',
|
||||||
|
content: newContent,
|
||||||
|
});
|
||||||
|
this.fetchConversation();
|
||||||
|
this.fetchMessages();
|
||||||
|
this.generateMessage(currConvId);
|
||||||
|
},
|
||||||
|
|
||||||
|
// settings dialog methods
|
||||||
|
closeAndSaveConfigDialog() {
|
||||||
|
try {
|
||||||
|
if (this.config.custom.length) JSON.parse(this.config.custom);
|
||||||
|
} catch (error) {
|
||||||
|
alert('Invalid JSON for custom config. Please either fix it or leave it empty.');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (const key of CONFIG_NUMERIC_KEYS) {
|
||||||
|
if (isNaN(this.config[key]) || this.config[key].toString().trim().length === 0) {
|
||||||
|
alert(`Invalid number for ${key} (expected an integer or a float)`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this.config[key] = parseFloat(this.config[key]);
|
||||||
|
}
|
||||||
|
this.showConfigDialog = false;
|
||||||
|
StorageUtils.setConfig(this.config);
|
||||||
|
},
|
||||||
|
closeAndDiscardConfigDialog() {
|
||||||
|
this.showConfigDialog = false;
|
||||||
|
this.config = StorageUtils.getConfig();
|
||||||
|
},
|
||||||
|
resetConfigDialog() {
|
||||||
|
if (window.confirm('Are you sure to reset all settings?')) {
|
||||||
|
this.config = {...CONFIG_DEFAULT};
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// sync state functions
|
||||||
|
fetchConversation() {
|
||||||
|
this.conversations = StorageUtils.getAllConversations();
|
||||||
|
},
|
||||||
|
fetchMessages() {
|
||||||
|
this.messages = StorageUtils.getOneConversation(this.viewingConvId)?.messages ?? [];
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
mainApp.config.errorHandler = alert;
|
||||||
|
try {
|
||||||
|
mainApp.mount('#app');
|
||||||
|
} catch (err) {
|
||||||
|
console.error(err);
|
||||||
|
document.getElementById('app').innerHTML = `<div style="margin:2em auto">
|
||||||
|
Failed to start app. Please try clearing localStorage and try again.<br/>
|
||||||
|
<br/>
|
||||||
|
<button class="btn" onClick="localStorage.clear(); window.location.reload();">Clear localStorage</button>
|
||||||
|
</div>`;
|
||||||
|
}
|
26
examples/server/webui/src/styles.css
Normal file
26
examples/server/webui/src/styles.css
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
@tailwind base;
|
||||||
|
@tailwind components;
|
||||||
|
@tailwind utilities;
|
||||||
|
|
||||||
|
.markdown {
|
||||||
|
h1, h2, h3, h4, h5, h6, ul, ol, li { all: revert; }
|
||||||
|
pre {
|
||||||
|
@apply whitespace-pre-wrap rounded-lg p-2;
|
||||||
|
border: 1px solid currentColor;
|
||||||
|
}
|
||||||
|
/* TODO: fix markdown table */
|
||||||
|
}
|
||||||
|
|
||||||
|
.show-on-hover {
|
||||||
|
@apply md:opacity-0 md:group-hover:opacity-100;
|
||||||
|
}
|
||||||
|
.btn-mini {
|
||||||
|
@apply cursor-pointer hover:shadow-md;
|
||||||
|
}
|
||||||
|
.chat-screen { max-width: 900px; }
|
||||||
|
|
||||||
|
.chat-bubble-base-300 {
|
||||||
|
--tw-bg-opacity: 1;
|
||||||
|
--tw-text-opacity: 1;
|
||||||
|
@apply bg-base-300 text-base-content;
|
||||||
|
}
|
16
examples/server/webui/tailwind.config.js
Normal file
16
examples/server/webui/tailwind.config.js
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
/** @type {import('tailwindcss').Config} */
|
||||||
|
export default {
|
||||||
|
content: [
|
||||||
|
"./index.html",
|
||||||
|
"./src/**/*.{js,ts,jsx,tsx}",
|
||||||
|
],
|
||||||
|
theme: {
|
||||||
|
extend: {},
|
||||||
|
},
|
||||||
|
plugins: [
|
||||||
|
require('daisyui'),
|
||||||
|
],
|
||||||
|
daisyui: {
|
||||||
|
themes: ['light', 'dark', 'cupcake', 'bumblebee', 'emerald', 'corporate', 'synthwave', 'retro', 'cyberpunk', 'valentine', 'halloween', 'garden', 'forest', 'aqua', 'lofi', 'pastel', 'fantasy', 'wireframe', 'black', 'luxury', 'dracula', 'cmyk', 'autumn', 'business', 'acid', 'lemonade', 'night', 'coffee', 'winter', 'dim', 'nord', 'sunset'],
|
||||||
|
}
|
||||||
|
}
|
36
examples/server/webui/vite.config.js
Normal file
36
examples/server/webui/vite.config.js
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
|
||||||
|
import { viteSingleFile } from 'vite-plugin-singlefile';
|
||||||
|
import path from 'path';
|
||||||
|
import fs from 'fs';
|
||||||
|
|
||||||
|
const GUIDE_FOR_FRONTEND = `
|
||||||
|
<!--
|
||||||
|
This is a single file build of the frontend.
|
||||||
|
It is automatically generated by the build process.
|
||||||
|
Do not edit this file directly.
|
||||||
|
To make changes, refer to the "Web UI" section in the README.
|
||||||
|
-->
|
||||||
|
`.trim();
|
||||||
|
|
||||||
|
export default {
|
||||||
|
plugins: [
|
||||||
|
viteSingleFile(),
|
||||||
|
(function llamaCppPlugin() {
|
||||||
|
let config;
|
||||||
|
return {
|
||||||
|
name: 'llamacpp:build',
|
||||||
|
apply: 'build',
|
||||||
|
async configResolved(_config) {
|
||||||
|
config = _config;
|
||||||
|
},
|
||||||
|
writeBundle() {
|
||||||
|
const outputIndexHtml = path.join(config.build.outDir, 'index.html');
|
||||||
|
const content = fs.readFileSync(outputIndexHtml, 'utf-8');
|
||||||
|
|
||||||
|
const targetOutputFile = path.join(config.build.outDir, '../../public/index.html');
|
||||||
|
fs.writeFileSync(targetOutputFile, GUIDE_FOR_FRONTEND + '\n' + content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})(),
|
||||||
|
],
|
||||||
|
};
|
@ -220,7 +220,6 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
||||||
|
|
||||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
||||||
}
|
}
|
||||||
|
@ -206,7 +206,6 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
float kqmax_new_j = kqmax_new_arr[j];
|
float kqmax_new_j = kqmax_new_arr[j];
|
||||||
|
|
||||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
||||||
}
|
}
|
||||||
|
@ -310,14 +310,14 @@ void ggml_aligned_free(void * ptr, size_t size);
|
|||||||
// FP16 to FP32 conversion
|
// FP16 to FP32 conversion
|
||||||
|
|
||||||
#if defined(__ARM_NEON)
|
#if defined(__ARM_NEON)
|
||||||
#ifdef _MSC_VER
|
#if defined(_MSC_VER) || (defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11)
|
||||||
typedef uint16_t ggml_fp16_internal_t;
|
typedef uint16_t ggml_fp16_internal_t;
|
||||||
#else
|
#else
|
||||||
typedef __fp16 ggml_fp16_internal_t;
|
typedef __fp16 ggml_fp16_internal_t;
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(__ARM_NEON) && !defined(_MSC_VER)
|
#if defined(__ARM_NEON) && !defined(_MSC_VER) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11)
|
||||||
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
|
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
|
||||||
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
|
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
|
||||||
|
|
||||||
|
@ -192,6 +192,30 @@ typedef struct {
|
|||||||
int16_t r3;
|
int16_t r3;
|
||||||
} ggml_metal_kargs_mul_mv;
|
} ggml_metal_kargs_mul_mv;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int32_t ne00;
|
||||||
|
int32_t ne01;
|
||||||
|
int32_t ne02;
|
||||||
|
uint64_t nb00;
|
||||||
|
uint64_t nb01;
|
||||||
|
uint64_t nb02;
|
||||||
|
uint64_t nb03;
|
||||||
|
int32_t ne10;
|
||||||
|
int32_t ne11;
|
||||||
|
int32_t ne12;
|
||||||
|
uint64_t nb10;
|
||||||
|
uint64_t nb11;
|
||||||
|
uint64_t nb12;
|
||||||
|
uint64_t nb13;
|
||||||
|
int32_t ne0;
|
||||||
|
int32_t ne1;
|
||||||
|
int16_t r2;
|
||||||
|
int16_t r3;
|
||||||
|
int16_t nsg;
|
||||||
|
int16_t nxpsg;
|
||||||
|
int16_t r1ptg;
|
||||||
|
} ggml_metal_kargs_mul_mv_ext;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t nei0;
|
int32_t nei0;
|
||||||
int32_t nei1;
|
int32_t nei1;
|
||||||
|
@ -175,6 +175,46 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
|
||||||
@ -266,6 +306,8 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
|
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
|
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
||||||
@ -350,6 +392,7 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
||||||
|
|
||||||
GGML_METAL_KERNEL_TYPE_COUNT
|
GGML_METAL_KERNEL_TYPE_COUNT
|
||||||
};
|
};
|
||||||
@ -699,6 +742,46 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
|
||||||
@ -790,6 +873,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
||||||
@ -872,6 +957,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
||||||
}
|
}
|
||||||
@ -989,6 +1075,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
case GGML_OP_SQRT:
|
case GGML_OP_SQRT:
|
||||||
@ -1001,6 +1088,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||||||
return has_simdgroup_reduction;
|
return has_simdgroup_reduction;
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
|
return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
|
||||||
|
case GGML_OP_ARGMAX:
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
return true;
|
return true;
|
||||||
@ -1928,30 +2016,180 @@ static void ggml_metal_encode_node(
|
|||||||
|
|
||||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||||
// to the matrix-vector kernel
|
// to the matrix-vector kernel
|
||||||
int ne11_mm_min = 4;
|
const int ne11_mm_min = 4;
|
||||||
|
|
||||||
#if 0
|
// first try to use small-batch mat-mv kernels
|
||||||
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
// these should be efficient for BS [2, ~8]
|
||||||
// these numbers do not translate to other devices or model sizes
|
if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
|
||||||
// TODO: need to find a better approach
|
(
|
||||||
if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
|
(
|
||||||
switch (src0t) {
|
(
|
||||||
case GGML_TYPE_F16: ne11_mm_min = 2; break;
|
src0t == GGML_TYPE_F16 || // TODO: helper function
|
||||||
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
src0t == GGML_TYPE_Q4_0 ||
|
||||||
case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
|
src0t == GGML_TYPE_Q4_1 ||
|
||||||
case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
|
src0t == GGML_TYPE_Q5_0 ||
|
||||||
|
src0t == GGML_TYPE_Q5_1 ||
|
||||||
|
src0t == GGML_TYPE_Q8_0 ||
|
||||||
|
src0t == GGML_TYPE_IQ4_NL ||
|
||||||
|
false) && (ne11 >= 2 && ne11 <= 8)
|
||||||
|
) ||
|
||||||
|
(
|
||||||
|
(
|
||||||
|
src0t == GGML_TYPE_Q4_K ||
|
||||||
|
src0t == GGML_TYPE_Q5_K ||
|
||||||
|
src0t == GGML_TYPE_Q6_K ||
|
||||||
|
false) && (ne11 >= 4 && ne11 <= 8)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
// TODO: determine the optimal parameters based on grid utilization
|
||||||
|
// I still don't know why we should not always use the maximum available threads:
|
||||||
|
//
|
||||||
|
// nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
|
||||||
|
//
|
||||||
|
// my current hypothesis is that the work grid is not evenly divisible for different nsg
|
||||||
|
// values and there can be some tail effects when nsg is high. need to confirm this
|
||||||
|
//
|
||||||
|
const int nsg = 2; // num simdgroups per threadgroup
|
||||||
|
const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
|
||||||
|
const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
|
||||||
|
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
|
||||||
|
int r1ptg = 4; // num src1 rows per threadgroup
|
||||||
|
|
||||||
|
// note: not sure how optimal are those across all different hardware. there might be someting cleverer
|
||||||
|
switch (ne11) {
|
||||||
|
case 2:
|
||||||
|
r1ptg = 2; break;
|
||||||
|
case 3:
|
||||||
|
case 6:
|
||||||
|
r1ptg = 3; break;
|
||||||
|
case 4:
|
||||||
|
case 7:
|
||||||
|
case 8:
|
||||||
|
r1ptg = 4; break;
|
||||||
|
case 5:
|
||||||
|
r1ptg = 5; break;
|
||||||
|
};
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
|
switch (r1ptg) {
|
||||||
case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
|
||||||
case GGML_TYPE_Q5_0: // not tested yet
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
|
||||||
case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
|
||||||
case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
|
||||||
case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
|
default: GGML_ABORT("not implemented");
|
||||||
default: ne11_mm_min = 1; break;
|
} break;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q4_K:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q5_K:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q6_K:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_IQ4_NL:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
}
|
}
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
ggml_metal_kargs_mul_mv_ext args = {
|
||||||
|
/*.ne00 =*/ ne00,
|
||||||
|
/*.ne01 =*/ ne01,
|
||||||
|
/*.ne02 =*/ ne02,
|
||||||
|
/*.nb00 =*/ nb00,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.nb02 =*/ nb02,
|
||||||
|
/*.nb03 =*/ nb03,
|
||||||
|
/*.ne10 =*/ ne10,
|
||||||
|
/*.ne11 =*/ ne11,
|
||||||
|
/*.ne12 =*/ ne12,
|
||||||
|
/*.nb10 =*/ nb10,
|
||||||
|
/*.nb11 =*/ nb11,
|
||||||
|
/*.nb12 =*/ nb12,
|
||||||
|
/*.nb13 =*/ nb13,
|
||||||
|
/*.ne0 =*/ ne0,
|
||||||
|
/*.ne1 =*/ ne1,
|
||||||
|
/*.r2 =*/ r2,
|
||||||
|
/*.r3 =*/ r3,
|
||||||
|
/*.nsg =*/ nsg,
|
||||||
|
/*.nxpsg =*/ nxpsg,
|
||||||
|
/*.r1ptg =*/ r1ptg,
|
||||||
|
};
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||||
|
|
||||||
|
//printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||||
|
} else
|
||||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||||
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
@ -2908,6 +3146,49 @@ static void ggml_metal_encode_node(
|
|||||||
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
||||||
|
|
||||||
|
const int32_t IC = src1->ne[1];
|
||||||
|
const int32_t IL = src1->ne[0];
|
||||||
|
|
||||||
|
const int32_t K = src0->ne[0];
|
||||||
|
|
||||||
|
const int32_t OL = dst->ne[0];
|
||||||
|
const int32_t OC = dst->ne[1];
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline;
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32: {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline;
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_F16: {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline;
|
||||||
|
} break;
|
||||||
|
default: GGML_ABORT("fatal error");
|
||||||
|
};
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
[encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
|
||||||
|
[encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
|
||||||
|
[encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
|
||||||
|
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
|
||||||
|
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
|
||||||
|
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
} break;
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
@ -3567,6 +3848,31 @@ static void ggml_metal_encode_node(
|
|||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_ARGMAX:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
|
||||||
|
|
||||||
|
const int64_t nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
|
int nth = 32; // SIMD width
|
||||||
|
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
|
||||||
|
nth *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
||||||
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||||
|
[encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||||
|
@ -47,6 +47,11 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|||||||
reg = (type4x4)(*src);
|
reg = (type4x4)(*src);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename type4>
|
||||||
|
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
|
||||||
|
reg = (type4)(*(src + il));
|
||||||
|
}
|
||||||
|
|
||||||
#if defined(GGML_METAL_USE_BF16)
|
#if defined(GGML_METAL_USE_BF16)
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
||||||
@ -73,6 +78,21 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
|
|||||||
reg = (type4x4) reg_f;
|
reg = (type4x4) reg_f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename type4>
|
||||||
|
void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
|
||||||
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||||
|
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
|
||||||
|
const float d2 = d1 / 256.f;
|
||||||
|
const float md = -8.h * xb->d;
|
||||||
|
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
|
||||||
|
const ushort mask1 = mask0 << 8;
|
||||||
|
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
|
||||||
|
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||||
@ -92,6 +112,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
|||||||
reg = (type4x4) reg_f;
|
reg = (type4x4) reg_f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename type4>
|
||||||
|
void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
|
||||||
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||||
|
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
|
||||||
|
const float d2 = d1 / 256.f;
|
||||||
|
const float m = xb->m;
|
||||||
|
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
|
||||||
|
const ushort mask1 = mask0 << 8;
|
||||||
|
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
|
||||||
|
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
|
void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
||||||
@ -124,6 +159,36 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
|
|||||||
reg = (type4x4) reg_f;
|
reg = (type4x4) reg_f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename type4>
|
||||||
|
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
|
||||||
|
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
||||||
|
const float d = xb->d;
|
||||||
|
const float md = -16.h * xb->d;
|
||||||
|
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
|
||||||
|
|
||||||
|
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||||
|
|
||||||
|
const int x_mv = (il/4) ? 4 : 0;
|
||||||
|
|
||||||
|
const int gh_mv = (il/4) ? 12 : 0;
|
||||||
|
const int gh_bk = (il/4) ? 0 : 4;
|
||||||
|
|
||||||
|
for (int ii = 0; ii < 2; ii++) {
|
||||||
|
int i = 2*(il%4) + ii;
|
||||||
|
|
||||||
|
// extract the 5-th bits for x0 and x1
|
||||||
|
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||||
|
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||||
|
|
||||||
|
// combine the 4-bits from qs with the 5th bit
|
||||||
|
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||||
|
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||||
|
|
||||||
|
reg[2*ii + 0] = d * x0 + md;
|
||||||
|
reg[2*ii + 1] = d * x1 + md;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
|
void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
||||||
@ -156,10 +221,40 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
|
|||||||
reg = (type4x4) reg_f;
|
reg = (type4x4) reg_f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename type4>
|
||||||
|
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
|
||||||
|
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
||||||
|
const float d = xb->d;
|
||||||
|
const float m = xb->m;
|
||||||
|
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
|
||||||
|
|
||||||
|
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||||
|
|
||||||
|
const int x_mv = (il/4) ? 4 : 0;
|
||||||
|
|
||||||
|
const int gh_mv = (il/4) ? 12 : 0;
|
||||||
|
const int gh_bk = (il/4) ? 0 : 4;
|
||||||
|
|
||||||
|
for (int ii = 0; ii < 2; ii++) {
|
||||||
|
int i = 2*(il%4) + ii;
|
||||||
|
|
||||||
|
// extract the 5-th bits for x0 and x1
|
||||||
|
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||||
|
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||||
|
|
||||||
|
// combine the 4-bits from qs with the 5th bit
|
||||||
|
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||||
|
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||||
|
|
||||||
|
reg[2*ii + 0] = d * x0 + m;
|
||||||
|
reg[2*ii + 1] = d * x1 + m;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
||||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||||
const half d = xb->d;
|
const float d = xb->d;
|
||||||
|
|
||||||
float4x4 reg_f;
|
float4x4 reg_f;
|
||||||
|
|
||||||
@ -170,6 +265,16 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|||||||
reg = (type4x4) reg_f;
|
reg = (type4x4) reg_f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename type4>
|
||||||
|
void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
|
||||||
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||||
|
const float d = xb->d;
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
||||||
const float d = xb->d;
|
const float d = xb->d;
|
||||||
@ -469,6 +574,19 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename type4>
|
||||||
|
void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
|
||||||
|
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
||||||
|
const float d = xb->d;
|
||||||
|
uint32_t aux32;
|
||||||
|
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
||||||
|
aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
|
||||||
|
reg[0] = d * kvalues_iq4nl_f[q8[0]];
|
||||||
|
reg[1] = d * kvalues_iq4nl_f[q8[1]];
|
||||||
|
reg[2] = d * kvalues_iq4nl_f[q8[2]];
|
||||||
|
reg[3] = d * kvalues_iq4nl_f[q8[3]];
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
||||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||||
@ -1248,6 +1366,63 @@ kernel void kernel_ssm_scan_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_argmax(
|
||||||
|
device const void * x,
|
||||||
|
device int32_t * dst,
|
||||||
|
constant int64_t & ncols,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
threadgroup float * shared_maxval [[threadgroup(0)]],
|
||||||
|
threadgroup int32_t * shared_argmax [[threadgroup(1)]],
|
||||||
|
uint tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
|
uint ntg[[threads_per_threadgroup]]) {
|
||||||
|
device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01);
|
||||||
|
|
||||||
|
float lmax = -INFINITY;
|
||||||
|
int32_t larg = -1;
|
||||||
|
|
||||||
|
for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
|
||||||
|
if (x_row[i00] > lmax) {
|
||||||
|
lmax = x_row[i00];
|
||||||
|
larg = i00;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the argmax value in the block
|
||||||
|
float max_val = simd_max(lmax);
|
||||||
|
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
|
||||||
|
|
||||||
|
if (ntg > N_SIMDWIDTH) {
|
||||||
|
if (sgitg == 0) {
|
||||||
|
shared_maxval[tiisg] = -INFINITY;
|
||||||
|
shared_argmax[tiisg] = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (tiisg == 0) {
|
||||||
|
shared_maxval[sgitg] = max_val;
|
||||||
|
shared_argmax[sgitg] = arg_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
max_val = shared_maxval[tiisg];
|
||||||
|
arg_val = shared_argmax[tiisg];
|
||||||
|
|
||||||
|
float max_val_reduced = simd_max(max_val);
|
||||||
|
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
|
||||||
|
|
||||||
|
dst[tgpig] = arg_val_reduced;
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[tgpig] = arg_val;
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_norm(
|
kernel void kernel_norm(
|
||||||
constant ggml_metal_kargs_norm & args,
|
constant ggml_metal_kargs_norm & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
@ -1752,6 +1927,301 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|||||||
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mat-vec kernel processing in chunks of float4
|
||||||
|
// chpb - chunks per quantization block
|
||||||
|
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
|
||||||
|
void kernel_mul_mv_ext_q4_f32_impl(
|
||||||
|
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
const short chpt = 4; // chunks per thread
|
||||||
|
|
||||||
|
//const short nxpsg = (32);
|
||||||
|
const short nypsg = (32/nxpsg);
|
||||||
|
|
||||||
|
const short tx = tiisg%nxpsg;
|
||||||
|
const short ty = tiisg/nxpsg;
|
||||||
|
|
||||||
|
const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
|
||||||
|
const int i11 = tgpig.y*r1ptg;
|
||||||
|
const int i1m = tgpig.z;
|
||||||
|
|
||||||
|
const int i12 = i1m%args.ne12;
|
||||||
|
const int i13 = i1m/args.ne12;
|
||||||
|
|
||||||
|
const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
|
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
|
device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
|
||||||
|
|
||||||
|
device const float4 * y4[r1ptg];
|
||||||
|
|
||||||
|
for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
|
||||||
|
y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
|
||||||
|
}
|
||||||
|
|
||||||
|
float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
|
||||||
|
|
||||||
|
short cch = tx%chpb; // current chunk index
|
||||||
|
|
||||||
|
for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
|
||||||
|
float4 lx[chpt];
|
||||||
|
|
||||||
|
#pragma unroll(chpt)
|
||||||
|
for (short ch = 0; ch < chpt; ++ch) {
|
||||||
|
deq_t4(xq, cch, lx[ch]);
|
||||||
|
|
||||||
|
cch += nxpsg;
|
||||||
|
if (cch >= chpb) {
|
||||||
|
xq += cch/chpb;
|
||||||
|
cch %= chpb;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll(chpt)
|
||||||
|
for (short ch = 0; ch < chpt; ++ch) {
|
||||||
|
#pragma unroll(r1ptg)
|
||||||
|
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
|
||||||
|
sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll(r1ptg)
|
||||||
|
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
|
||||||
|
y4[ir1] += chpt*nxpsg;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reduce only the threads in each row
|
||||||
|
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
|
||||||
|
if (nxpsg >= 32) {
|
||||||
|
sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 16) {
|
||||||
|
sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 8) {
|
||||||
|
sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 4) {
|
||||||
|
sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 2) {
|
||||||
|
sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
//sumf[ir1] = simd_sum(sumf[ir1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tx == 0) {
|
||||||
|
for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
|
||||||
|
device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
|
||||||
|
|
||||||
|
if (i01 < args.ne01) {
|
||||||
|
dst_f32[i01] = sumf[ir1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mat-vec kernel processing in chunks of float4x4
|
||||||
|
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
|
||||||
|
void kernel_mul_mv_ext_q4x4_f32_impl(
|
||||||
|
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
const short chpt = 1;
|
||||||
|
|
||||||
|
//const short nxpsg = (32);
|
||||||
|
const short nypsg = (32/nxpsg);
|
||||||
|
|
||||||
|
const short tx = tiisg%nxpsg;
|
||||||
|
const short ty = tiisg/nxpsg;
|
||||||
|
|
||||||
|
const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
|
||||||
|
const int i11 = tgpig.y*r1ptg;
|
||||||
|
const int i1m = tgpig.z;
|
||||||
|
|
||||||
|
const int i12 = i1m%args.ne12;
|
||||||
|
const int i13 = i1m/args.ne12;
|
||||||
|
|
||||||
|
const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
|
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
|
device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
|
||||||
|
|
||||||
|
device const float4x4 * y4x4[r1ptg];
|
||||||
|
|
||||||
|
for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
|
||||||
|
y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
|
||||||
|
}
|
||||||
|
|
||||||
|
float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
|
||||||
|
|
||||||
|
short cch = tx%chpb;
|
||||||
|
|
||||||
|
for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
|
||||||
|
float4x4 lx[chpt];
|
||||||
|
|
||||||
|
#pragma unroll(chpt)
|
||||||
|
for (short ch = 0; ch < chpt; ++ch) {
|
||||||
|
deq_t4x4(xq, cch, lx[ch]);
|
||||||
|
|
||||||
|
cch += nxpsg;
|
||||||
|
if (cch >= chpb) {
|
||||||
|
xq += cch/chpb;
|
||||||
|
cch %= chpb;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll(chpt)
|
||||||
|
for (short ch = 0; ch < chpt; ++ch) {
|
||||||
|
#pragma unroll(r1ptg)
|
||||||
|
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
|
||||||
|
sumf[ir1] +=
|
||||||
|
dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
|
||||||
|
dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
|
||||||
|
dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
|
||||||
|
dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll(r1ptg)
|
||||||
|
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
|
||||||
|
y4x4[ir1] += chpt*nxpsg;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
|
||||||
|
if (nxpsg >= 32) {
|
||||||
|
sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 16) {
|
||||||
|
sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 8) {
|
||||||
|
sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 4) {
|
||||||
|
sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 2) {
|
||||||
|
sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
//sumf[ir1] = simd_sum(sumf[ir1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tx == 0) {
|
||||||
|
for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
|
||||||
|
device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
|
||||||
|
|
||||||
|
if (i01 < args.ne01) {
|
||||||
|
dst_f32[i01] = sumf[ir1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatchers needed for compile-time nxpsg
|
||||||
|
// epb - elements per quantization block
|
||||||
|
template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
|
||||||
|
kernel void kernel_mul_mv_ext_q4_f32_disp(
|
||||||
|
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
switch (args.nxpsg) {
|
||||||
|
case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||||
|
case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||||
|
case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||||
|
case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
|
||||||
|
kernel void kernel_mul_mv_ext_q4x4_f32_disp(
|
||||||
|
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
switch (args.nxpsg) {
|
||||||
|
case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||||
|
case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||||
|
case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||||
|
case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
|
||||||
|
typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
|
||||||
|
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
|
||||||
|
|
||||||
#define N_MV_T_T 4
|
#define N_MV_T_T 4
|
||||||
|
|
||||||
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
|
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
|
||||||
@ -2258,6 +2728,79 @@ kernel void kernel_im2col_ext(
|
|||||||
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
||||||
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
||||||
|
|
||||||
|
typedef void (conv_transpose_1d_t)(
|
||||||
|
device const float * src0,
|
||||||
|
device const float * src1,
|
||||||
|
device char * dst,
|
||||||
|
constant int32_t & IC,
|
||||||
|
constant int32_t & IL,
|
||||||
|
constant int32_t & K,
|
||||||
|
constant int32_t & s0,
|
||||||
|
constant uint64_t & nb0,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tgpg[[threadgroups_per_grid]]);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
kernel void kernel_conv_transpose_1d(
|
||||||
|
device const T * src0,
|
||||||
|
device const float * src1,
|
||||||
|
device char * dst,
|
||||||
|
constant int32_t & IC,
|
||||||
|
constant int32_t & IL,
|
||||||
|
constant int32_t & K,
|
||||||
|
constant int32_t & s0,
|
||||||
|
constant uint64_t & nb0,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tgpg[[threadgroups_per_grid]]) {
|
||||||
|
|
||||||
|
float v = 0.0f;
|
||||||
|
|
||||||
|
for (int64_t c = 0; c < IC; c++) {
|
||||||
|
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
|
||||||
|
const int32_t input_offset = c * IL;
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < IL; i++) {
|
||||||
|
if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
|
||||||
|
v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
|
||||||
|
|
||||||
|
dst_ptr[0] = v;
|
||||||
|
}
|
||||||
|
|
||||||
|
template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
|
||||||
|
kernel void kernel_conv_transpose_1d<float>(
|
||||||
|
device const float * src0,
|
||||||
|
device const float * src1,
|
||||||
|
device char * dst,
|
||||||
|
constant int32_t & IC,
|
||||||
|
constant int32_t & IL,
|
||||||
|
constant int32_t & K,
|
||||||
|
constant int32_t & s0,
|
||||||
|
constant uint64_t & nb0,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tgpg[[threadgroups_per_grid]]);
|
||||||
|
|
||||||
|
template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
|
||||||
|
kernel void kernel_conv_transpose_1d<half>(
|
||||||
|
device const half * src0,
|
||||||
|
device const float * src1,
|
||||||
|
device char * dst,
|
||||||
|
constant int32_t & IC,
|
||||||
|
constant int32_t & IL,
|
||||||
|
constant int32_t & K,
|
||||||
|
constant int32_t & s0,
|
||||||
|
constant uint64_t & nb0,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tgpg[[threadgroups_per_grid]]);
|
||||||
|
|
||||||
kernel void kernel_upscale_f32(
|
kernel void kernel_upscale_f32(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
|
@ -68,7 +68,8 @@ else()
|
|||||||
target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
|
target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
|
||||||
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
|
||||||
target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
|
add_compile_definitions(GGML_SYCL_NVIDIA)
|
||||||
|
target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas)
|
||||||
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
||||||
if (NOT GGML_SYCL_DEVICE_ARCH)
|
if (NOT GGML_SYCL_DEVICE_ARCH)
|
||||||
message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
|
message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
|
||||||
|
@ -1689,9 +1689,14 @@ namespace dpct
|
|||||||
auto data_a = get_memory<const Ta>(a);
|
auto data_a = get_memory<const Ta>(a);
|
||||||
auto data_b = get_memory<const Tb>(b);
|
auto data_b = get_memory<const Tb>(b);
|
||||||
auto data_c = get_memory<Tc>(c);
|
auto data_c = get_memory<Tc>(c);
|
||||||
oneapi::mkl::blas::column_major::gemm(
|
#ifdef GGML_SYCL_NVIDIA
|
||||||
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
|
||||||
data_b, ldb, beta_value, data_c, ldc);
|
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
||||||
|
beta_value, data_c, ldc);
|
||||||
|
#else
|
||||||
|
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
||||||
|
beta_value, data_c, ldc);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename VecT, class BinaryOperation, class = void>
|
template <typename VecT, class BinaryOperation, class = void>
|
||||||
@ -1754,14 +1759,22 @@ namespace dpct
|
|||||||
matrix_info->ld_info[2] = ldc;
|
matrix_info->ld_info[2] = ldc;
|
||||||
matrix_info->groupsize_info = batch_size;
|
matrix_info->groupsize_info = batch_size;
|
||||||
|
|
||||||
|
#ifdef GGML_SYCL_NVIDIA
|
||||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||||
q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
||||||
matrix_info->size_info, matrix_info->size_info + 1,
|
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
||||||
matrix_info->size_info + 2, matrix_info->value_info,
|
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
|
||||||
reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
||||||
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
|
||||||
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
&(matrix_info->groupsize_info));
|
||||||
|
#else
|
||||||
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||||
|
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
||||||
|
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
|
||||||
|
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
||||||
|
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
||||||
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
||||||
|
#endif
|
||||||
|
|
||||||
q.submit([&](sycl::handler &cgh)
|
q.submit([&](sycl::handler &cgh)
|
||||||
{
|
{
|
||||||
@ -1783,10 +1796,16 @@ namespace dpct
|
|||||||
auto data_a = get_memory<const Ta>(a);
|
auto data_a = get_memory<const Ta>(a);
|
||||||
auto data_b = get_memory<const Tb>(b);
|
auto data_b = get_memory<const Tb>(b);
|
||||||
auto data_c = get_memory<Tc>(c);
|
auto data_c = get_memory<Tc>(c);
|
||||||
|
#ifdef GGML_SYCL_NVIDIA
|
||||||
oneapi::mkl::blas::column_major::gemm_batch(
|
oneapi::mkl::blas::column_major::gemm_batch(
|
||||||
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
|
||||||
stride_a, data_b, ldb, stride_b, beta_value,
|
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
|
||||||
data_c, ldc, stride_c, batch_size);
|
batch_size);
|
||||||
|
#else
|
||||||
|
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
||||||
|
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
|
||||||
|
stride_c, batch_size);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
@ -2573,12 +2573,17 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
|||||||
const float alpha = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const float beta = 0.0f;
|
const float beta = 0.0f;
|
||||||
#if !GGML_SYCL_DNNL
|
#if !GGML_SYCL_DNNL
|
||||||
|
# ifdef GGML_SYCL_NVIDIA
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
||||||
*stream, oneapi::mkl::transpose::trans,
|
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
|
||||||
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
|
||||||
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
|
ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
||||||
src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
|
# else
|
||||||
|
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
||||||
|
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||||
|
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
|
||||||
dst_dd_i, ldc)));
|
dst_dd_i, ldc)));
|
||||||
|
# endif
|
||||||
#else
|
#else
|
||||||
auto dnnl_stream = ctx.stream_dnnl(stream);
|
auto dnnl_stream = ctx.stream_dnnl(stream);
|
||||||
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
|
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
|
||||||
|
@ -40,14 +40,14 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
// Perform matrix multiplication using oneMKL GEMM
|
// Perform matrix multiplication using oneMKL GEMM
|
||||||
oneapi::mkl::blas::column_major::gemm(*stream,
|
#ifdef GGML_SYCL_NVIDIA
|
||||||
oneapi::mkl::transpose::nontrans, src1_op,
|
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
|
||||||
ne0, ne1, ne01,
|
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
|
||||||
alpha,
|
ne00, src1_d, ldb, beta, dst_d, ne0);
|
||||||
src0_d, ne00,
|
#else
|
||||||
src1_d, ldb,
|
oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
|
||||||
beta,
|
src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
|
||||||
dst_d, ne0);
|
#endif
|
||||||
}
|
}
|
||||||
catch (sycl::exception const& exc) {
|
catch (sycl::exception const& exc) {
|
||||||
std::cerr << exc.what() << std::endl;
|
std::cerr << exc.what() << std::endl;
|
||||||
|
@ -165,6 +165,7 @@ struct vk_device_struct {
|
|||||||
vk_queue transfer_queue;
|
vk_queue transfer_queue;
|
||||||
bool single_queue;
|
bool single_queue;
|
||||||
uint32_t subgroup_size;
|
uint32_t subgroup_size;
|
||||||
|
uint32_t shader_core_count;
|
||||||
bool uma;
|
bool uma;
|
||||||
|
|
||||||
size_t idx;
|
size_t idx;
|
||||||
@ -1498,7 +1499,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
||||||
@ -1610,11 +1611,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
|
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
|
||||||
|
|
||||||
bool maintenance4_support = false;
|
bool maintenance4_support = false;
|
||||||
|
bool sm_builtins = false;
|
||||||
|
|
||||||
// Check if maintenance4 is supported
|
// Check if maintenance4 is supported
|
||||||
for (const auto& properties : ext_props) {
|
for (const auto& properties : ext_props) {
|
||||||
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
||||||
maintenance4_support = true;
|
maintenance4_support = true;
|
||||||
|
} else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
|
||||||
|
sm_builtins = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1622,11 +1626,21 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
vk::PhysicalDeviceMaintenance3Properties props3;
|
vk::PhysicalDeviceMaintenance3Properties props3;
|
||||||
vk::PhysicalDeviceMaintenance4Properties props4;
|
vk::PhysicalDeviceMaintenance4Properties props4;
|
||||||
vk::PhysicalDeviceSubgroupProperties subgroup_props;
|
vk::PhysicalDeviceSubgroupProperties subgroup_props;
|
||||||
|
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
|
||||||
props2.pNext = &props3;
|
props2.pNext = &props3;
|
||||||
props3.pNext = &subgroup_props;
|
props3.pNext = &subgroup_props;
|
||||||
|
|
||||||
|
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
|
||||||
|
|
||||||
if (maintenance4_support) {
|
if (maintenance4_support) {
|
||||||
subgroup_props.pNext = &props4;
|
last_struct->pNext = (VkBaseOutStructure *)&props4;
|
||||||
|
last_struct = (VkBaseOutStructure *)&props4;
|
||||||
}
|
}
|
||||||
|
if (sm_builtins) {
|
||||||
|
last_struct->pNext = (VkBaseOutStructure *)&sm_props;
|
||||||
|
last_struct = (VkBaseOutStructure *)&sm_props;
|
||||||
|
}
|
||||||
|
|
||||||
device->physical_device.getProperties2(&props2);
|
device->physical_device.getProperties2(&props2);
|
||||||
device->properties = props2.properties;
|
device->properties = props2.properties;
|
||||||
|
|
||||||
@ -1643,6 +1657,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
device->vendor_id = device->properties.vendorID;
|
device->vendor_id = device->properties.vendorID;
|
||||||
device->subgroup_size = subgroup_props.subgroupSize;
|
device->subgroup_size = subgroup_props.subgroupSize;
|
||||||
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
||||||
|
if (sm_builtins) {
|
||||||
|
device->shader_core_count = sm_props.shaderSMCount;
|
||||||
|
} else {
|
||||||
|
device->shader_core_count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
bool fp16_storage = false;
|
bool fp16_storage = false;
|
||||||
bool fp16_compute = false;
|
bool fp16_compute = false;
|
||||||
@ -2732,15 +2751,25 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
|
|||||||
dst->device->device.resetFences({ dst->device->fence });
|
dst->device->device.resetFences({ dst->device->fence });
|
||||||
}
|
}
|
||||||
|
|
||||||
static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
|
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
|
||||||
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
|
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
|
||||||
// if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
|
|
||||||
// return 4;
|
|
||||||
// }
|
|
||||||
|
|
||||||
return 1;
|
uint32_t split_k = 1;
|
||||||
|
if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
|
||||||
|
// If k is 'large' and the SMs will fill less than halfway, use split_k.
|
||||||
|
uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
|
||||||
|
uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
|
||||||
|
if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
|
||||||
|
split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
|
||||||
|
// Clamp to 2 or 4
|
||||||
|
split_k = std::min(split_k, 4u);
|
||||||
|
if (split_k == 3) {
|
||||||
|
split_k = 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
|
return split_k;
|
||||||
}
|
}
|
||||||
|
|
||||||
static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
|
static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
|
||||||
@ -2964,10 +2993,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||||||
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
|
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
|
||||||
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
||||||
|
|
||||||
const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
|
|
||||||
|
|
||||||
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
|
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
|
||||||
|
|
||||||
|
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
|
||||||
|
|
||||||
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
||||||
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
||||||
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
||||||
@ -2993,7 +3022,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||||||
if (dryrun) {
|
if (dryrun) {
|
||||||
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
||||||
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
||||||
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
|
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
|
||||||
if (
|
if (
|
||||||
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
|
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
|
||||||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
|
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
|
||||||
|
@ -5,7 +5,9 @@
|
|||||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||||
|
layout (binding = 0) readonly buffer A4 {vec4 data_a4[];};
|
||||||
layout (binding = 1) writeonly buffer D {float data_d[];};
|
layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||||
|
layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];};
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
layout (push_constant) uniform parameter {
|
||||||
uint ne;
|
uint ne;
|
||||||
@ -13,17 +15,34 @@ layout (push_constant) uniform parameter {
|
|||||||
} p;
|
} p;
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint idx = gl_GlobalInvocationID.x;
|
// Each invocation handles four consecutive components
|
||||||
|
const uint idx = gl_GlobalInvocationID.x * 4;
|
||||||
|
|
||||||
if (idx >= p.ne) {
|
if (idx >= p.ne) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if all four components are in bounds and aligned,
|
||||||
|
// then use vector loads
|
||||||
|
if (idx + 3 < p.ne && (p.ne % 4) == 0) {
|
||||||
|
vec4 result = vec4(0.0f);
|
||||||
|
|
||||||
|
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
|
||||||
|
result += data_a4[(i * p.ne + idx) / 4];
|
||||||
|
}
|
||||||
|
|
||||||
|
data_d4[idx / 4] = result;
|
||||||
|
} else {
|
||||||
|
[[unroll]] for (uint j = 0; j < 4; ++j) {
|
||||||
|
if (idx + j < p.ne) {
|
||||||
float result = 0.0f;
|
float result = 0.0f;
|
||||||
|
|
||||||
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
|
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
|
||||||
result += data_a[i * p.ne + idx];
|
result += data_a[i * p.ne + idx + j];
|
||||||
}
|
}
|
||||||
|
|
||||||
data_d[idx] = result;
|
data_d[idx + j] = result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ Terminals support the full range of Unicode. Unicode characters can be specified
|
|||||||
|
|
||||||
Character ranges can be negated with `^`:
|
Character ranges can be negated with `^`:
|
||||||
```
|
```
|
||||||
single-line ::= [^\n]+ "\n"`
|
single-line ::= [^\n]+ "\n"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Sequences and Alternatives
|
## Sequences and Alternatives
|
||||||
|
@ -991,7 +991,7 @@ extern "C" {
|
|||||||
int32_t length);
|
int32_t length);
|
||||||
|
|
||||||
// Get list of built-in chat templates
|
// Get list of built-in chat templates
|
||||||
int32_t llama_chat_builtin_templates(const char ** output, size_t len);
|
LLAMA_API int32_t llama_chat_builtin_templates(const char ** output, size_t len);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Sampling API
|
// Sampling API
|
||||||
|
@ -73,7 +73,6 @@ while read c; do
|
|||||||
src/ggml*.h \
|
src/ggml*.h \
|
||||||
src/ggml*.c \
|
src/ggml*.c \
|
||||||
src/ggml*.cpp \
|
src/ggml*.cpp \
|
||||||
src/ggml-amx/* \
|
|
||||||
src/ggml-blas/* \
|
src/ggml-blas/* \
|
||||||
src/ggml-cann/* \
|
src/ggml-cann/* \
|
||||||
src/ggml-cpu/* \
|
src/ggml-cpu/* \
|
||||||
@ -124,7 +123,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
|
|||||||
# src/ggml*.c -> ggml/src/ggml*.c
|
# src/ggml*.c -> ggml/src/ggml*.c
|
||||||
# src/ggml*.cpp -> ggml/src/ggml*.cpp
|
# src/ggml*.cpp -> ggml/src/ggml*.cpp
|
||||||
# src/ggml*.h -> ggml/src/ggml*.h
|
# src/ggml*.h -> ggml/src/ggml*.h
|
||||||
# src/ggml-amx/* -> ggml/src/ggml-amx/*
|
|
||||||
# src/ggml-blas/* -> ggml/src/ggml-blas/*
|
# src/ggml-blas/* -> ggml/src/ggml-blas/*
|
||||||
# src/ggml-cann/* -> ggml/src/ggml-cann/*
|
# src/ggml-cann/* -> ggml/src/ggml-cann/*
|
||||||
# src/ggml-cpu/* -> ggml/src/ggml-cpu/*
|
# src/ggml-cpu/* -> ggml/src/ggml-cpu/*
|
||||||
@ -151,7 +149,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
|
|||||||
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \
|
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \
|
||||||
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \
|
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \
|
||||||
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \
|
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \
|
||||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-amx\//\1ggml\/src\/ggml-amx\//g' \
|
|
||||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-blas\//\1ggml\/src\/ggml-blas\//g' \
|
-e 's/([[:space:]]|[ab]\/)src\/ggml-blas\//\1ggml\/src\/ggml-blas\//g' \
|
||||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \
|
-e 's/([[:space:]]|[ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \
|
||||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \
|
-e 's/([[:space:]]|[ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \
|
||||||
|
@ -1 +1 @@
|
|||||||
c598cbe30621251e80acbcf3b601589a37c17f4d
|
b903ffe79daf18c0aaacbebe44a7b93a6b8d0982
|
||||||
|
@ -7,7 +7,6 @@ cp -rpv ../ggml/cmake/FindSIMD.cmake ./ggml/cmake/FindSIMD.cmake
|
|||||||
cp -rpv ../ggml/src/ggml*.c ./ggml/src/
|
cp -rpv ../ggml/src/ggml*.c ./ggml/src/
|
||||||
cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/
|
cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/
|
||||||
cp -rpv ../ggml/src/ggml*.h ./ggml/src/
|
cp -rpv ../ggml/src/ggml*.h ./ggml/src/
|
||||||
cp -rpv ../ggml/src/ggml-amx/* ./ggml/src/ggml-amx/
|
|
||||||
cp -rpv ../ggml/src/ggml-blas/* ./ggml/src/ggml-blas/
|
cp -rpv ../ggml/src/ggml-blas/* ./ggml/src/ggml-blas/
|
||||||
cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/
|
cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/
|
||||||
cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/
|
cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/
|
||||||
|
@ -3460,13 +3460,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||||||
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
|
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
|
||||||
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
|
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_argmax());
|
test_cases.emplace_back(new test_count_equal());
|
||||||
|
|
||||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
|
||||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
|
||||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
|
||||||
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
|
||||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
|
||||||
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438, 3, 1, 1}));
|
||||||
test_cases.emplace_back(new test_count_equal());
|
|
||||||
|
|
||||||
for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
|
for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
|
||||||
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
|
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
|
||||||
@ -3572,6 +3573,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
|
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
|
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||||
|
|
||||||
|
for (int i = 1; i < 9; ++i) {
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
for (ggml_type type_a : base_types) {
|
for (ggml_type type_a : base_types) {
|
||||||
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||||
|
Loading…
Reference in New Issue
Block a user