mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
Merge branch 'master' into mpi
This commit is contained in:
commit
0f557c2ac4
33
.devops/full-cuda.Dockerfile
Normal file
33
.devops/full-cuda.Dockerfile
Normal file
@ -0,0 +1,33 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG CUDA_VERSION=11.7.1
|
||||
|
||||
# Target the CUDA build image
|
||||
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
FROM ${BASE_CUDA_DEV_CONTAINER} as build
|
||||
|
||||
# Unless otherwise specified, we make a fat build.
|
||||
ARG CUDA_DOCKER_ARCH=all
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential python3 python3-pip
|
||||
|
||||
COPY requirements.txt requirements.txt
|
||||
|
||||
RUN pip install --upgrade pip setuptools wheel \
|
||||
&& pip install -r requirements.txt
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
# Set nvcc architecture
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
# Enable cuBLAS
|
||||
ENV LLAMA_CUBLAS=1
|
||||
|
||||
RUN make
|
||||
|
||||
ENTRYPOINT ["/app/.devops/tools.sh"]
|
32
.devops/main-cuda.Dockerfile
Normal file
32
.devops/main-cuda.Dockerfile
Normal file
@ -0,0 +1,32 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG CUDA_VERSION=11.7.1
|
||||
# Target the CUDA build image
|
||||
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
# Target the CUDA runtime image
|
||||
ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
FROM ${BASE_CUDA_DEV_CONTAINER} as build
|
||||
|
||||
# Unless otherwise specified, we make a fat build.
|
||||
ARG CUDA_DOCKER_ARCH=all
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
# Set nvcc architecture
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
# Enable cuBLAS
|
||||
ENV LLAMA_CUBLAS=1
|
||||
|
||||
RUN make
|
||||
|
||||
FROM ${BASE_CUDA_RUN_CONTAINER} as runtime
|
||||
|
||||
COPY --from=build /app/main /main
|
||||
|
||||
ENTRYPOINT [ "/main" ]
|
14
.github/workflows/build.yml
vendored
14
.github/workflows/build.yml
vendored
@ -16,7 +16,10 @@ on:
|
||||
paths: ['**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu']
|
||||
|
||||
env:
|
||||
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
||||
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
||||
GGML_NLOOP: 3
|
||||
GGML_NITER: 1
|
||||
GGML_N_THREADS: 1
|
||||
|
||||
jobs:
|
||||
ubuntu-focal-make:
|
||||
@ -64,7 +67,7 @@ jobs:
|
||||
id: cmake_test
|
||||
run: |
|
||||
cd build
|
||||
ctest --verbose
|
||||
ctest --verbose --timeout 900
|
||||
|
||||
ubuntu-latest-cmake-sanitizer:
|
||||
runs-on: ubuntu-latest
|
||||
@ -99,7 +102,7 @@ jobs:
|
||||
id: cmake_test
|
||||
run: |
|
||||
cd build
|
||||
ctest --verbose
|
||||
ctest --verbose --timeout 900
|
||||
|
||||
ubuntu-latest-cmake-mpi:
|
||||
runs-on: ubuntu-latest
|
||||
@ -175,10 +178,11 @@ jobs:
|
||||
id: cmake_test
|
||||
run: |
|
||||
cd build
|
||||
ctest --verbose
|
||||
ctest --verbose --timeout 900
|
||||
|
||||
windows-latest-cmake:
|
||||
runs-on: windows-latest
|
||||
|
||||
env:
|
||||
OPENBLAS_VERSION: 0.3.23
|
||||
OPENCL_VERSION: 2023.04.17
|
||||
@ -277,7 +281,7 @@ jobs:
|
||||
if: ${{ matrix.build != 'clblast' && (matrix.build != 'avx512' || env.HAS_AVX512F == '1') }} # Test AVX-512 only when possible
|
||||
run: |
|
||||
cd build
|
||||
ctest -C Release --verbose
|
||||
ctest -C Release --verbose --timeout 900
|
||||
|
||||
- name: Get commit hash
|
||||
id: commit
|
||||
|
@ -218,6 +218,9 @@ if (LLAMA_BLAS)
|
||||
message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}")
|
||||
add_compile_options(${BLAS_LINKER_FLAGS})
|
||||
add_compile_definitions(GGML_USE_OPENBLAS)
|
||||
if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel"))
|
||||
add_compile_definitions(GGML_BLAS_USE_MKL)
|
||||
endif()
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES})
|
||||
set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS})
|
||||
|
||||
|
8
Makefile
8
Makefile
@ -172,7 +172,12 @@ ifdef LLAMA_CUBLAS
|
||||
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
|
||||
OBJS += ggml-cuda.o
|
||||
NVCC = nvcc
|
||||
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
|
||||
NVCCFLAGS = --forward-unknown-to-host-compiler
|
||||
ifdef CUDA_DOCKER_ARCH
|
||||
NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
|
||||
else
|
||||
NVCCFLAGS += -arch=native
|
||||
endif # CUDA_DOCKER_ARCH
|
||||
ifdef LLAMA_CUDA_FORCE_DMMV
|
||||
NVCCFLAGS += -DGGML_CUDA_FORCE_DMMV
|
||||
endif # LLAMA_CUDA_FORCE_DMMV
|
||||
@ -196,6 +201,7 @@ ifdef LLAMA_CUDA_KQUANTS_ITER
|
||||
else
|
||||
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
|
||||
endif
|
||||
|
||||
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
|
||||
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
||||
endif # LLAMA_CUBLAS
|
||||
|
41
README.md
41
README.md
@ -734,7 +734,7 @@ export LD_LIBRARY_PATH=/vendor/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
For easy and swift re-execution, consider documenting this final part in a .sh script file. This will enable you to rerun the process with minimal hassle.
|
||||
|
||||
Place your desired model into the `/llama.cpp/models/` directory and execute the `./main (...)` script.
|
||||
Place your desired model into the `~/llama.cpp/models/` directory and execute the `./main (...)` script.
|
||||
|
||||
### Docker
|
||||
|
||||
@ -770,6 +770,38 @@ or with a light image:
|
||||
docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:light -m /models/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -n 512
|
||||
```
|
||||
|
||||
### Docker With CUDA
|
||||
|
||||
Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) properly installed on Linux, or is using a GPU enabled cloud, `cuBLAS` should be accessible inside the container.
|
||||
|
||||
#### Building Locally
|
||||
|
||||
```bash
|
||||
docker build -t local/llama.cpp:full-cuda -f .devops/full-cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:light-cuda -f .devops/main-cuda.Dockerfile .
|
||||
```
|
||||
|
||||
You may want to pass in some different `ARGS`, depending on the CUDA environment supported by your container host, as well as the GPU architecture.
|
||||
|
||||
The defaults are:
|
||||
|
||||
- `CUDA_VERSION` set to `11.7.1`
|
||||
- `CUDA_DOCKER_ARCH` set to `all`
|
||||
|
||||
The resulting images, are essentially the same as the non-CUDA images:
|
||||
|
||||
1. `local/llama.cpp:full-cuda`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
|
||||
2. `local/llama.cpp:light-cuda`: This image only includes the main executable file.
|
||||
|
||||
#### Usage
|
||||
|
||||
After building locally, Usage is similar to the non-CUDA examples, but you'll need to add the `--gpus` flag. You will also want to use the `--n-gpu-layers` flag.
|
||||
|
||||
```bash
|
||||
docker run --gpus all -v /path/to/models:/models local/llama.cpp:full-cuda --run -m /models/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1
|
||||
docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /models/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1
|
||||
```
|
||||
|
||||
### Contributing
|
||||
|
||||
- Contributors can open PRs
|
||||
@ -790,5 +822,10 @@ docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:light -m /mode
|
||||
|
||||
### Docs
|
||||
|
||||
- [GGML tips & tricks](https://github.com/ggerganov/llama.cpp/wiki/GGML-Tips-&-Tricks)
|
||||
- [main](./examples/main/README.md)
|
||||
- [server](./examples/server/README.md)
|
||||
- [embd-input](./examples/embd-input/README.md)
|
||||
- [jeopardy](./examples/jeopardy/README.md)
|
||||
- [BLIS](./docs/BLIS.md)
|
||||
- [Performance troubleshooting](./docs/token_generation_performance_tips.md)
|
||||
- [GGML tips & tricks](https://github.com/ggerganov/llama.cpp/wiki/GGML-Tips-&-Tricks)
|
||||
|
@ -828,6 +828,7 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
|
||||
|
||||
|
||||
SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
|
||||
'BF16': DT_BF16,
|
||||
'F16': DT_F16,
|
||||
'F32': DT_F32,
|
||||
'I32': DT_I32,
|
||||
|
@ -31,6 +31,17 @@ float frand_normal(struct random_normal_distribution * rnd) {
|
||||
return ((r < rnd->min) ? (rnd->min) : (r > rnd->max) ? (rnd->max) : r);
|
||||
}
|
||||
|
||||
void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
|
||||
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
|
||||
|
||||
if (plan.work_size > 0) {
|
||||
buf.resize(plan.work_size);
|
||||
plan.work_data = buf.data();
|
||||
}
|
||||
|
||||
ggml_graph_compute(graph, &plan);
|
||||
}
|
||||
|
||||
struct ggml_tensor * randomize_tensor(
|
||||
struct ggml_tensor * tensor,
|
||||
int ndims,
|
||||
@ -1569,6 +1580,8 @@ int main(int argc, char ** argv) {
|
||||
int n_tokens = model.hparams.n_ctx;
|
||||
int n_vocab = model.hparams.n_vocab;
|
||||
|
||||
std::vector<uint8_t> work_buffer;
|
||||
|
||||
for (int ex=0; ex<n_examples; ++ex) {
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ compute_size,
|
||||
@ -1586,7 +1599,6 @@ int main(int argc, char ** argv) {
|
||||
int n_past = 0;
|
||||
|
||||
ggml_cgraph gf = {};
|
||||
gf.n_threads = 1;
|
||||
|
||||
get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets);
|
||||
|
||||
@ -1595,7 +1607,7 @@ int main(int argc, char ** argv) {
|
||||
struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
|
||||
|
||||
ggml_build_forward_expand(&gf, e);
|
||||
ggml_graph_compute(ctx0, &gf);
|
||||
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
|
||||
|
||||
float error_before_opt = ggml_get_f32_1d(e, 0);
|
||||
|
||||
@ -1611,7 +1623,7 @@ int main(int argc, char ** argv) {
|
||||
ggml_opt(ctx0, opt_params_lbfgs, e);
|
||||
//
|
||||
ggml_build_forward_expand(&gf, e);
|
||||
ggml_graph_compute(ctx0, &gf);
|
||||
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
|
||||
|
||||
float error_after_opt = ggml_get_f32_1d(e, 0);
|
||||
|
||||
@ -1659,13 +1671,12 @@ int main(int argc, char ** argv) {
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph gf = {};
|
||||
gf.n_threads = 1;
|
||||
|
||||
int n_past = 0;
|
||||
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
|
||||
|
||||
ggml_build_forward_expand(&gf, logits);
|
||||
ggml_graph_compute(ctx0, &gf);
|
||||
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
|
||||
|
||||
struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
|
||||
struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
|
||||
@ -1687,10 +1698,11 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
print_matrix(model.tok_embeddings);
|
||||
|
||||
printf("done\n");
|
||||
|
||||
// ggml_free(kv_self.ctx);
|
||||
// ggml_free(model_lora.ctx);
|
||||
ggml_free(model.ctx);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
@ -20,6 +20,17 @@
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
|
||||
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
|
||||
|
||||
if (plan.work_size > 0) {
|
||||
buf.resize(plan.work_size);
|
||||
plan.work_data = buf.data();
|
||||
}
|
||||
|
||||
ggml_graph_compute(graph, &plan);
|
||||
}
|
||||
|
||||
float tensor_sum_elements(const ggml_tensor * tensor) {
|
||||
float sum = 0;
|
||||
if (tensor->type==GGML_TYPE_F32) {
|
||||
@ -159,13 +170,14 @@ int main(int argc, char ** argv) {
|
||||
// printf("Creating compute graph\n");
|
||||
struct ggml_cgraph gf = ggml_build_forward(m11xm2);
|
||||
|
||||
gf.n_threads=benchmark_params.n_threads;
|
||||
printf("cgraph->n_threads=%i\n",gf.n_threads);
|
||||
printf("n_threads=%i\n", benchmark_params.n_threads);
|
||||
|
||||
TENSOR_DUMP(m11);
|
||||
TENSOR_DUMP(m2);
|
||||
|
||||
ggml_graph_compute(ctx, &gf);
|
||||
std::vector<uint8_t> work_buffer;
|
||||
|
||||
ggml_graph_compute_helper(work_buffer, &gf, benchmark_params.n_threads);
|
||||
|
||||
TENSOR_DUMP(gf.nodes[0]);
|
||||
|
||||
@ -187,7 +199,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// printf("Creating compute graph\n");
|
||||
struct ggml_cgraph gf31 = ggml_build_forward(q31);
|
||||
gf31.n_threads=benchmark_params.n_threads;
|
||||
|
||||
// Set up a second graph computation to make sure we override the CPU cache lines
|
||||
// printf("Creating new tensor q12 & Running quantize\n");
|
||||
@ -199,8 +210,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
//printf("Creating compute graph\n");
|
||||
struct ggml_cgraph gf32 = ggml_build_forward(q32);
|
||||
gf32.n_threads=benchmark_params.n_threads;
|
||||
printf("cgraph->n_threads=%i\n",gf31.n_threads);
|
||||
printf("n_threads=%i\n", benchmark_params.n_threads);
|
||||
|
||||
const int dimx = sizex;
|
||||
const int dimy = sizey;
|
||||
@ -221,14 +231,15 @@ int main(int argc, char ** argv) {
|
||||
|
||||
long long int start = ggml_time_us();
|
||||
//printf("Running ggml_graph_compute\n");
|
||||
ggml_graph_compute(ctx, &gf31);
|
||||
ggml_graph_compute_helper(work_buffer, &gf31, benchmark_params.n_threads);
|
||||
|
||||
long long int stop = ggml_time_us();
|
||||
long long int usec = stop-start;
|
||||
double gflops = (double)(flops_per_matrix)/usec/1000.0;
|
||||
gflops_sum += gflops;
|
||||
printf("%9i;%8i;%6i;%6i;%6i;%15lli;%18lli;%10.2f\n",
|
||||
i,
|
||||
gf31.n_threads,
|
||||
benchmark_params.n_threads,
|
||||
sizex, sizey, sizez, flops_per_matrix,
|
||||
usec,gflops);
|
||||
|
||||
@ -253,7 +264,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// Running a different graph computation to make sure we override the CPU cache lines
|
||||
ggml_graph_compute(ctx, &gf32);
|
||||
ggml_graph_compute_helper(work_buffer, &gf32, benchmark_params.n_threads);
|
||||
}
|
||||
printf("\n");
|
||||
printf("Average%78.2f\n",gflops_sum/((double)benchmark_params.n_iterations));
|
||||
|
@ -418,6 +418,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||
|
||||
if (escape_prompt) {
|
||||
process_escapes(params.prompt);
|
||||
process_escapes(params.input_prefix);
|
||||
process_escapes(params.input_suffix);
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -35,10 +35,9 @@ int main(int argc, char ** argv) {
|
||||
struct ggml_context * ctx_eval = NULL;
|
||||
|
||||
struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
|
||||
gf.n_threads = 1;
|
||||
|
||||
// this allocates all Metal resources and memory buffers
|
||||
auto * ctx_metal = ggml_metal_init();
|
||||
auto * ctx_metal = ggml_metal_init(1);
|
||||
|
||||
const size_t max_size_data = ggml_get_max_tensor_size(ctx_data);
|
||||
const size_t max_size_eval = ggml_get_max_tensor_size(ctx_eval);
|
||||
|
@ -60,6 +60,17 @@ float frand_uniform(struct random_uniform_distribution * rnd) {
|
||||
return rnd->rd(rnd->gen);
|
||||
}
|
||||
|
||||
void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
|
||||
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
|
||||
|
||||
if (plan.work_size > 0) {
|
||||
buf.resize(plan.work_size);
|
||||
plan.work_data = buf.data();
|
||||
}
|
||||
|
||||
ggml_graph_compute(graph, &plan);
|
||||
}
|
||||
|
||||
struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
|
||||
float scale = 1.0f; // xavier
|
||||
switch (tensor->n_dims) {
|
||||
@ -1426,11 +1437,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
|
||||
|
||||
gf->n_nodes = 0;
|
||||
gf->n_leafs = 0;
|
||||
gf->work_size = 0;
|
||||
gf->perf_runs = 0;
|
||||
gf->perf_cycles = 0;
|
||||
gf->perf_time_us = 0;
|
||||
gf->work = NULL;
|
||||
|
||||
const auto & hparams = model->hparams;
|
||||
//const int n_ctx = hparams.n_ctx;
|
||||
@ -3162,6 +3171,7 @@ int main(int argc, char ** argv) {
|
||||
printf("used_mem model+cache: %zu bytes\n", ggml_used_mem(model.ctx));
|
||||
// ggml_print_tensor_objects(model.ctx);
|
||||
|
||||
// TODO: use std::vector<uint8_t> intead of "new"
|
||||
size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
|
||||
uint8_t * compute_addr = new uint8_t[compute_size];
|
||||
|
||||
@ -3183,6 +3193,8 @@ int main(int argc, char ** argv) {
|
||||
GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size());
|
||||
}
|
||||
|
||||
std::vector<uint8_t> work_buffer;
|
||||
|
||||
printf("%s: begin training\n", __func__);
|
||||
|
||||
for (int ex = 0; ex < params.n_examples; ++ex) {
|
||||
@ -3217,9 +3229,6 @@ int main(int argc, char ** argv) {
|
||||
struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
|
||||
struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
|
||||
|
||||
// ggml_cgraph gf = {};
|
||||
gf->n_threads = params.n_threads;
|
||||
gb->n_threads = params.n_threads;
|
||||
|
||||
get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
|
||||
|
||||
@ -3248,7 +3257,7 @@ int main(int argc, char ** argv) {
|
||||
*gb = ggml_build_backward(ctx0, gf, true);
|
||||
}
|
||||
|
||||
ggml_graph_compute(ctx0, gf);
|
||||
ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
|
||||
|
||||
size_t used_mem_before_opt = ggml_used_mem(ctx0);
|
||||
|
||||
@ -3272,7 +3281,7 @@ int main(int argc, char ** argv) {
|
||||
model.train_samples += n_batch;
|
||||
model.train_tokens += n_batch * n_tokens;
|
||||
|
||||
ggml_graph_compute(ctx0, gf);
|
||||
ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
|
||||
|
||||
float error_after_opt = ggml_get_f32_1d(loss, 0);
|
||||
|
||||
@ -3354,13 +3363,12 @@ int main(int argc, char ** argv) {
|
||||
struct ggml_context * ctx0 = ggml_init(cparams);
|
||||
|
||||
ggml_cgraph gf = {};
|
||||
gf.n_threads = params.n_threads;
|
||||
|
||||
int n_past = 0;
|
||||
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
|
||||
|
||||
ggml_build_forward_expand(&gf, logits);
|
||||
ggml_graph_compute(ctx0, &gf);
|
||||
ggml_graph_compute_helper(work_buffer, &gf, params.n_threads);
|
||||
|
||||
//struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
|
||||
//struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
|
||||
@ -3386,6 +3394,7 @@ int main(int argc, char ** argv) {
|
||||
delete[] compute_addr;
|
||||
delete[] compute_buf_0;
|
||||
delete[] compute_buf_1;
|
||||
|
||||
llama_free(lctx);
|
||||
llama_free_model(lmodel);
|
||||
ggml_free(model.ctx);
|
||||
|
93
ggml-cuda.cu
93
ggml-cuda.cu
@ -59,8 +59,8 @@ typedef float2 dfloat2;
|
||||
#endif //GGML_CUDA_DMMV_F16
|
||||
|
||||
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
|
||||
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
|
||||
typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
|
||||
typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
|
||||
typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
|
||||
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
||||
typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
typedef void (*ggml_cuda_op_t)(
|
||||
@ -131,7 +131,7 @@ typedef struct {
|
||||
} block_q8_1;
|
||||
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
|
||||
|
||||
typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
|
||||
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs);
|
||||
|
||||
//================================= k-quants
|
||||
|
||||
@ -208,6 +208,7 @@ typedef struct {
|
||||
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
|
||||
|
||||
#define WARP_SIZE 32
|
||||
#define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||
|
||||
#define CUDA_ADD_BLOCK_SIZE 256
|
||||
#define CUDA_MUL_BLOCK_SIZE 256
|
||||
@ -407,7 +408,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
|
||||
|
||||
//================================== k-quants
|
||||
|
||||
static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
|
||||
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_q2_K * x = (const block_q2_K *) vx;
|
||||
@ -440,7 +441,7 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
|
||||
|
||||
}
|
||||
|
||||
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
|
||||
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_q3_K * x = (const block_q3_K *) vx;
|
||||
@ -504,7 +505,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
|
||||
}
|
||||
#endif
|
||||
|
||||
static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
|
||||
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
const block_q4_K * x = (const block_q4_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
@ -544,7 +545,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
|
||||
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
const block_q5_K * x = (const block_q5_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
@ -590,7 +591,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
|
||||
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
const block_q6_K * x = (const block_q6_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
@ -634,7 +635,7 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
|
||||
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
||||
|
||||
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
||||
|
||||
@ -742,7 +743,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
|
||||
static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
||||
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
if (row > nrows) return;
|
||||
@ -846,7 +847,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
|
||||
static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
||||
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
if (row > nrows) return;
|
||||
@ -949,7 +950,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
|
||||
static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
|
||||
|
||||
const int row = blockIdx.x;
|
||||
const int num_blocks_per_row = ncols / QK_K;
|
||||
@ -1053,7 +1054,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
|
||||
static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
||||
|
||||
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
||||
|
||||
@ -1171,7 +1172,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
|
||||
v.y = x[ib + iqs + 1];
|
||||
}
|
||||
|
||||
static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
|
||||
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int ndata, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
@ -1180,10 +1181,10 @@ static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
|
||||
|
||||
block_q8_1 * y = (block_q8_1 *) vy;
|
||||
|
||||
const int ib = i / QK8_0; // block index
|
||||
const int iqs = i % QK8_0; // quant index
|
||||
const int ib = i / QK8_1; // block index
|
||||
const int iqs = i % QK8_1; // quant index
|
||||
|
||||
const float xi = x[i];
|
||||
const float xi = i < ndata ? x[i] : 0.0f;
|
||||
float amax = fabsf(xi);
|
||||
float sum = xi;
|
||||
|
||||
@ -1207,7 +1208,7 @@ static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static __global__ void dequantize_block(const void * vx, float * y, const int k) {
|
||||
static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
@ -1227,7 +1228,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
|
||||
y[iybs + iqs + y_offset] = v.y;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
|
||||
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
|
||||
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
|
||||
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
|
||||
|
||||
@ -1252,7 +1253,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, cons
|
||||
#endif // __CUDA_ARCH__ >= 600
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
|
||||
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
|
||||
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
|
||||
const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
|
||||
|
||||
@ -1277,7 +1278,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, cons
|
||||
#endif // __CUDA_ARCH__ >= 600
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
|
||||
static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
|
||||
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
|
||||
const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
|
||||
|
||||
@ -1312,7 +1313,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, cons
|
||||
#endif // __CUDA_ARCH__ >= 600
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
|
||||
static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
|
||||
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
|
||||
const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
|
||||
|
||||
@ -1346,7 +1347,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, cons
|
||||
#endif // __CUDA_ARCH__ >= 600
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
|
||||
static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
|
||||
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
|
||||
const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
|
||||
|
||||
@ -1366,7 +1367,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, cons
|
||||
}
|
||||
|
||||
template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||
static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
|
||||
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
|
||||
if (row >= nrows) {
|
||||
@ -1404,7 +1405,7 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
|
||||
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||
// qk = quantized weights per x block
|
||||
// qr = number of quantized weights per data value in x block
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
@ -1471,7 +1472,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
|
||||
static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
|
||||
const half * x = (const half *) vx;
|
||||
|
||||
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
@ -1518,7 +1519,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
|
||||
}
|
||||
|
||||
static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
||||
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
|
||||
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
|
||||
const int row_stride_x, const int channel_stride_x) {
|
||||
|
||||
const half * x = (const half *) vx;
|
||||
@ -1714,9 +1715,9 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
|
||||
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
}
|
||||
|
||||
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int k, cudaStream_t stream) {
|
||||
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
||||
quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, k);
|
||||
quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, ndata, k);
|
||||
}
|
||||
|
||||
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
@ -2355,16 +2356,15 @@ inline void ggml_cuda_op_mul_mat_vec(
|
||||
src0->type == GGML_TYPE_Q5_1 ||
|
||||
src0->type == GGML_TYPE_Q8_0;
|
||||
|
||||
// The integer intrinsics used in mul_mat_vec_q are available with compute capability 6.
|
||||
// However, they have bad performance with Pascal cards.
|
||||
// Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q.
|
||||
const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented;
|
||||
const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented;
|
||||
#endif
|
||||
|
||||
if (use_mul_mat_vec_q) {
|
||||
int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1;
|
||||
padded_row_size -= padded_row_size % MATRIX_ROW_PADDING;
|
||||
size_t as;
|
||||
void * src1_q8_1 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_1)/QK8_1, &as);
|
||||
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, cudaStream_main);
|
||||
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
|
||||
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
@ -3108,7 +3108,11 @@ void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
|
||||
|
||||
void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
|
||||
int nrows = ggml_nrows(tensor);
|
||||
|
||||
const int64_t ne0 = tensor->ne[0];
|
||||
|
||||
const size_t nb1 = tensor->nb[1];
|
||||
|
||||
ggml_backend backend = tensor->backend;
|
||||
struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
|
||||
memset(extra, 0, sizeof(*extra));
|
||||
@ -3137,11 +3141,24 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
|
||||
int64_t nrows_split = row_high - row_low;
|
||||
|
||||
const size_t offset_split = row_low*nb1;
|
||||
const size_t size = ggml_nbytes_split(tensor, nrows_split);
|
||||
size_t size = ggml_nbytes_split(tensor, nrows_split);
|
||||
const size_t original_size = size;
|
||||
|
||||
void * buf;
|
||||
// pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses
|
||||
if (ne0 % MATRIX_ROW_PADDING != 0) {
|
||||
size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
|
||||
* ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
|
||||
}
|
||||
|
||||
char * buf;
|
||||
CUDA_CHECK(cudaMalloc(&buf, size));
|
||||
void * buf_host = (char*)data + offset_split;
|
||||
char * buf_host = (char*)data + offset_split;
|
||||
|
||||
// set padding to 0 to avoid possible NaN values
|
||||
if (size > original_size) {
|
||||
CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
|
||||
}
|
||||
|
||||
|
||||
cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
|
||||
|
||||
|
@ -34,9 +34,13 @@ extern "C" {
|
||||
|
||||
struct ggml_metal_context;
|
||||
|
||||
struct ggml_metal_context * ggml_metal_init(void);
|
||||
// number of command buffers to use
|
||||
struct ggml_metal_context * ggml_metal_init(int n_cb);
|
||||
void ggml_metal_free(struct ggml_metal_context * ctx);
|
||||
|
||||
// set the number of command buffers to use
|
||||
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
|
||||
|
||||
// creates a mapping between a host memory buffer and a device memory buffer
|
||||
// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
|
||||
// - the mapping is used during computation to determine the arguments of the compute kernels
|
||||
|
11
ggml-metal.m
11
ggml-metal.m
@ -25,6 +25,8 @@ struct ggml_metal_buffer {
|
||||
};
|
||||
|
||||
struct ggml_metal_context {
|
||||
int n_cb;
|
||||
|
||||
float * logits;
|
||||
|
||||
id<MTLDevice> device;
|
||||
@ -86,11 +88,12 @@ static NSString * const msl_library_source = @"see metal.metal";
|
||||
@implementation GGMLMetalClass
|
||||
@end
|
||||
|
||||
struct ggml_metal_context * ggml_metal_init(void) {
|
||||
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
fprintf(stderr, "%s: allocating\n", __func__);
|
||||
|
||||
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
||||
|
||||
ctx->n_cb = n_cb;
|
||||
ctx->device = MTLCreateSystemDefaultDevice();
|
||||
ctx->queue = [ctx->device newCommandQueue];
|
||||
ctx->n_buffers = 0;
|
||||
@ -208,6 +211,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
free(ctx);
|
||||
}
|
||||
|
||||
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
|
||||
ctx->n_cb = n_cb;
|
||||
}
|
||||
|
||||
// finds the Metal buffer that contains the tensor data on the GPU device
|
||||
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
||||
// Metal buffer based on the host memory pointer
|
||||
@ -354,7 +361,7 @@ void ggml_metal_graph_compute(
|
||||
// create multiple command buffers and enqueue them
|
||||
// then, we encode the graph into the command buffers in parallel
|
||||
|
||||
const int n_cb = gf->n_threads;
|
||||
const int n_cb = ctx->n_cb;
|
||||
|
||||
NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
|
||||
|
||||
|
@ -653,13 +653,17 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
|
||||
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
|
||||
const int in = tid - step*im; // 0...15 or 0...7
|
||||
|
||||
#if K_QUANTS_PER_ITERATION == 1
|
||||
\n#if K_QUANTS_PER_ITERATION == 1\n
|
||||
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
|
||||
const int is = 0;
|
||||
#else
|
||||
|
||||
\n#else\n
|
||||
|
||||
const int l0 = 4 * in; // 0, 4, 8, ..., 28
|
||||
const int is = in / 4;
|
||||
#endif
|
||||
|
||||
\n#endif\n
|
||||
|
||||
const int ql_offset = 64*im + l0;
|
||||
const int qh_offset = 32*im + l0;
|
||||
const int s_offset = 8*im + is;
|
||||
@ -676,7 +680,7 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
|
||||
|
||||
const float d = vload_half(0, &x[i].d);
|
||||
|
||||
#if K_QUANTS_PER_ITERATION == 1
|
||||
\n#if K_QUANTS_PER_ITERATION == 1\n
|
||||
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
|
||||
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
|
||||
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
|
||||
@ -686,7 +690,7 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
|
||||
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
|
||||
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
|
||||
tmp[16 * ix + tid] += sum;
|
||||
#else
|
||||
\n#else\n
|
||||
float sum = 0;
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
|
||||
@ -695,7 +699,7 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
|
||||
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
|
||||
}
|
||||
tmp[16 * ix + tid] += sum;
|
||||
#endif
|
||||
\n#endif\n
|
||||
|
||||
}
|
||||
|
||||
|
766
ggml.c
766
ggml.c
@ -247,7 +247,11 @@ inline static void* ggml_aligned_malloc(size_t size) {
|
||||
#include "ggml-opencl.h"
|
||||
#endif
|
||||
#elif defined(GGML_USE_OPENBLAS)
|
||||
#if defined(GGML_BLAS_USE_MKL)
|
||||
#include <mkl.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
#elif defined(GGML_USE_CUBLAS)
|
||||
#include "ggml-cuda.h"
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
@ -4583,14 +4587,13 @@ struct ggml_tensor * ggml_new_tensor_impl(
|
||||
/*.src0 =*/ NULL,
|
||||
/*.src1 =*/ NULL,
|
||||
/*.opt =*/ { NULL },
|
||||
/*.n_tasks =*/ 0,
|
||||
/*.perf_runs =*/ 0,
|
||||
/*.perf_cycles =*/ 0,
|
||||
/*.perf_time_us =*/ 0,
|
||||
/*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
|
||||
/*.name =*/ { 0 },
|
||||
/*.extra =*/ NULL,
|
||||
/*.pad =*/ { 0 },
|
||||
/*.padding =*/ { 0 },
|
||||
};
|
||||
|
||||
// TODO: this should not be needed as long as we don't rely on aligned SIMD loads
|
||||
@ -10718,8 +10721,6 @@ static void ggml_compute_forward_mul_mat(
|
||||
|
||||
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
assert(ne00 % 32 == 0);
|
||||
|
||||
for (int64_t ic = 0; ic < ne11; ++ic) {
|
||||
vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
|
||||
}
|
||||
@ -15772,9 +15773,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
|
||||
struct ggml_cgraph result = {
|
||||
/*.n_nodes =*/ 0,
|
||||
/*.n_leafs =*/ 0,
|
||||
/*.n_threads =*/ GGML_DEFAULT_N_THREADS,
|
||||
/*.work_size =*/ 0,
|
||||
/*.work =*/ NULL,
|
||||
/*.nodes =*/ { NULL },
|
||||
/*.grads =*/ { NULL },
|
||||
/*.leafs =*/ { NULL },
|
||||
@ -15945,12 +15943,13 @@ void clear_numa_thread_affinity(void) {}
|
||||
#endif
|
||||
|
||||
struct ggml_compute_state_shared {
|
||||
struct ggml_cgraph * cgraph;
|
||||
const struct ggml_cgraph * cgraph;
|
||||
const struct ggml_cplan * cplan;
|
||||
|
||||
int64_t perf_node_start_cycles;
|
||||
int64_t perf_node_start_time_us;
|
||||
|
||||
int n_threads;
|
||||
const int n_threads;
|
||||
|
||||
// synchronization primitives
|
||||
atomic_int n_active; // num active threads
|
||||
@ -15974,9 +15973,13 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
|
||||
|
||||
static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
|
||||
struct ggml_cgraph * cgraph = state->shared->cgraph;
|
||||
|
||||
const int n_threads = state->shared->n_threads;
|
||||
const struct ggml_cgraph * cgraph = state->shared->cgraph;
|
||||
const struct ggml_cplan * cplan = state->shared->cplan;
|
||||
|
||||
const int * n_tasks_arr = cplan->n_tasks;
|
||||
const int n_threads = state->shared->n_threads;
|
||||
|
||||
set_numa_thread_affinity(state->ith, n_threads);
|
||||
|
||||
int node_n = -1;
|
||||
@ -15989,15 +15992,15 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
/*.type =*/ GGML_TASK_FINALIZE,
|
||||
/*.ith =*/ 0,
|
||||
/*.nth =*/ 0,
|
||||
/*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
|
||||
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
|
||||
/*.wsize =*/ cplan->work_size,
|
||||
/*.wdata =*/ cplan->work_data,
|
||||
};
|
||||
|
||||
if (node_n != -1) {
|
||||
/* FINALIZE */
|
||||
struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
|
||||
if (GGML_OP_HAS_FINALIZE[node->op]) {
|
||||
params.nth = node->n_tasks;
|
||||
params.nth = n_tasks_arr[node_n];
|
||||
ggml_compute_forward(¶ms, node);
|
||||
ggml_graph_compute_perf_stats_node(node, state->shared);
|
||||
}
|
||||
@ -16008,11 +16011,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
|
||||
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
const int n_tasks = n_tasks_arr[node_n];
|
||||
|
||||
state->shared->perf_node_start_cycles = ggml_perf_cycles();
|
||||
state->shared->perf_node_start_time_us = ggml_perf_time_us();
|
||||
|
||||
params.nth = node->n_tasks;
|
||||
params.nth = n_tasks;
|
||||
|
||||
/* INIT */
|
||||
if (GGML_OP_HAS_INIT[node->op]) {
|
||||
@ -16020,7 +16024,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
ggml_compute_forward(¶ms, node);
|
||||
}
|
||||
|
||||
if (node->n_tasks == 1) {
|
||||
if (n_tasks == 1) {
|
||||
// TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
|
||||
// they do something more efficient than spinning (?)
|
||||
params.type = GGML_TASK_COMPUTE;
|
||||
@ -16042,7 +16046,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
// wait for other threads to finish
|
||||
const int last = node_n;
|
||||
do {
|
||||
sched_yield();
|
||||
//sched_yield();
|
||||
node_n = atomic_load(&state->shared->node_n);
|
||||
} while (node_n == last);
|
||||
}
|
||||
@ -16052,16 +16056,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
|
||||
/* COMPUTE */
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
const int n_tasks = n_tasks_arr[node_n];
|
||||
|
||||
struct ggml_compute_params params = {
|
||||
/*.type =*/ GGML_TASK_COMPUTE,
|
||||
/*.ith =*/ state->ith,
|
||||
/*.nth =*/ node->n_tasks,
|
||||
/*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
|
||||
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
|
||||
/*.nth =*/ n_tasks,
|
||||
/*.wsize =*/ cplan->work_size,
|
||||
/*.wdata =*/ cplan->work_data,
|
||||
};
|
||||
|
||||
if (state->ith < node->n_tasks) {
|
||||
if (state->ith < n_tasks) {
|
||||
ggml_compute_forward(¶ms, node);
|
||||
}
|
||||
}
|
||||
@ -16069,11 +16074,364 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
|
||||
const int n_threads = cgraph->n_threads;
|
||||
struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
||||
if (n_threads <= 0) {
|
||||
n_threads = GGML_DEFAULT_N_THREADS;
|
||||
}
|
||||
|
||||
size_t work_size = 0;
|
||||
|
||||
struct ggml_cplan cplan;
|
||||
memset(&cplan, 0, sizeof(struct ggml_cplan));
|
||||
|
||||
// thread scheduling for the different operations + work buffer size estimation
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
int n_tasks = 1;
|
||||
|
||||
struct ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
switch (node->op) {
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
|
||||
size_t cur = 0;
|
||||
if (ggml_is_quantized(node->type)) {
|
||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
|
||||
}
|
||||
|
||||
work_size = MAX(work_size, cur);
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD1:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
|
||||
size_t cur = 0;
|
||||
|
||||
if (ggml_is_quantized(node->src0->type)) {
|
||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_tasks;
|
||||
}
|
||||
|
||||
work_size = MAX(work_size, cur);
|
||||
} break;
|
||||
case GGML_OP_ACC:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
|
||||
size_t cur = 0;
|
||||
|
||||
if (ggml_is_quantized(node->src0->type)) {
|
||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_tasks;
|
||||
}
|
||||
|
||||
work_size = MAX(work_size, cur);
|
||||
} break;
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_LOG:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
case GGML_OP_ABS:
|
||||
case GGML_OP_SGN:
|
||||
case GGML_OP_NEG:
|
||||
case GGML_OP_STEP:
|
||||
case GGML_OP_TANH:
|
||||
case GGML_OP_ELU:
|
||||
case GGML_OP_RELU:
|
||||
{
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_GELU:
|
||||
case GGML_OP_GELU_QUICK:
|
||||
case GGML_OP_SILU:
|
||||
case GGML_OP_SILU_BACK:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_OUT_PROD:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
|
||||
// TODO: use different scheduling for different matrix sizes
|
||||
//const int nr0 = ggml_nrows(node->src0);
|
||||
//const int nr1 = ggml_nrows(node->src1);
|
||||
|
||||
//n_tasks = MIN(n_threads, MAX(1, nr0/128));
|
||||
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
|
||||
|
||||
size_t cur = 0;
|
||||
const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
|
||||
n_tasks = 1; // TODO: this actually is doing nothing
|
||||
// the threads are still spinning
|
||||
} else
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
|
||||
n_tasks = 1; // TODO: this actually is doing nothing
|
||||
// the threads are still spinning
|
||||
cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
|
||||
} else
|
||||
#endif
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||
n_tasks = 1; // TODO: this actually is doing nothing
|
||||
// the threads are still spinning
|
||||
if (node->src0->type != GGML_TYPE_F32) {
|
||||
// here we need memory just for single 2D matrix from src0
|
||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
if (node->src1->type != vec_dot_type) {
|
||||
cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
|
||||
} else {
|
||||
cur = 0;
|
||||
}
|
||||
|
||||
work_size = MAX(work_size, cur);
|
||||
} break;
|
||||
case GGML_OP_SCALE:
|
||||
{
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_DIAG_MASK_ZERO:
|
||||
{
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_ALIBI:
|
||||
{
|
||||
n_tasks = 1; //TODO
|
||||
} break;
|
||||
case GGML_OP_CLAMP:
|
||||
{
|
||||
n_tasks = 1; //TODO
|
||||
} break;
|
||||
case GGML_OP_CONV_1D:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
|
||||
GGML_ASSERT(node->src0->ne[3] == 1);
|
||||
GGML_ASSERT(node->src1->ne[2] == 1);
|
||||
GGML_ASSERT(node->src1->ne[3] == 1);
|
||||
|
||||
size_t cur = 0;
|
||||
const int nk = node->src0->ne[0];
|
||||
|
||||
if (node->src0->type == GGML_TYPE_F16 &&
|
||||
node->src1->type == GGML_TYPE_F32) {
|
||||
cur = sizeof(ggml_fp16_t)*(
|
||||
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
|
||||
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
|
||||
);
|
||||
} else if (node->src0->type == GGML_TYPE_F32 &&
|
||||
node->src1->type == GGML_TYPE_F32) {
|
||||
cur = sizeof(float)*(
|
||||
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
|
||||
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
|
||||
);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
work_size = MAX(work_size, cur);
|
||||
} break;
|
||||
case GGML_OP_CONV_2D:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
|
||||
GGML_ASSERT(node->src1->ne[3] == 1);
|
||||
|
||||
const int64_t ne00 = node->src0->ne[0]; // W
|
||||
const int64_t ne01 = node->src0->ne[1]; // H
|
||||
const int64_t ne02 = node->src0->ne[2]; // C
|
||||
const int64_t ne03 = node->src0->ne[3]; // N
|
||||
|
||||
const int64_t ne10 = node->src1->ne[0]; // W
|
||||
const int64_t ne11 = node->src1->ne[1]; // H
|
||||
const int64_t ne12 = node->src1->ne[2]; // C

const int64_t nk = ne00*ne01;

UNUSED(ne02);
UNUSED(ne03);
UNUSED(nk);

size_t cur = 0;

if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
} else if (node->src0->type == GGML_TYPE_F32 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)* (ne10*ne11*ne12);
} else {
GGML_ASSERT(false);
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_ATTN:
{
n_tasks = n_threads;

size_t cur = 0;

const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);

if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}

if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_FF:
{
n_tasks = n_threads;

size_t cur = 0;

if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
}

if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
n_tasks = n_threads;

size_t cur = 0;

const int64_t D = node->src0->ne[0];
const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}

if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1:
case GGML_OP_MAP_CUSTOM2:
case GGML_OP_MAP_CUSTOM3:
{
n_tasks = 1;
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
n_tasks = n_threads;

size_t cur = ggml_type_size(node->type)*(n_tasks + node->src0->ne[0]*n_tasks);

work_size = MAX(work_size, cur);
} break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{
n_tasks = n_threads;

size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*n_tasks;

work_size = MAX(work_size, cur);
} break;
case GGML_OP_NONE:
{
n_tasks = 1;
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);
} break;
}

cplan.n_tasks[i] = n_tasks;
}

if (work_size > 0) {
work_size += CACHE_LINE_SIZE*(n_threads - 1);
}

cplan.n_threads = n_threads;
cplan.work_size = work_size;
cplan.work_data = NULL;

return cplan;
}
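With ggml_graph_plan() returning both the per-node schedule and the required scratch size, the caller of the low-level API now owns the work buffer. A minimal sketch of the new two-step flow, assuming a graph built with ggml_build_forward(); the helper name and the malloc-based buffer are illustrative, not part of this commit:

#include <stdlib.h>
#include "ggml.h"

// sketch: plan first, then hand ggml_graph_compute() a caller-owned work buffer
static void compute_graph(struct ggml_cgraph * gf, int n_threads) {
    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);

    uint8_t * work = NULL;
    if (cplan.work_size > 0) {
        work = (uint8_t *) malloc(cplan.work_size); // required whenever work_size > 0
        cplan.work_data = work;
    }

    ggml_graph_compute(gf, &cplan);

    free(work);
}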
void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
{
GGML_ASSERT(cplan);
GGML_ASSERT(cplan->n_threads > 0);

if (cplan->work_size > 0) {
GGML_ASSERT(cplan->work_data);
}

for (int i = 0; i < cgraph->n_nodes; ++i) {
if (cgraph->nodes[i]->op != GGML_OP_NONE) {
GGML_ASSERT(cplan->n_tasks[i] > 0);
}
}
}

const int n_threads = cplan->n_threads;

struct ggml_compute_state_shared state_shared = {
/*.cgraph =*/ cgraph,
/*.cgraph_plan =*/ cplan,
/*.perf_node_start_cycles =*/ 0,
/*.perf_node_start_time_us =*/ 0,
/*.n_threads =*/ n_threads,
@ -16082,336 +16440,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
};
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);

// initialize tasks + work buffer
{
size_t work_size = 0;

// thread scheduling for the different operations
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];

switch (node->op) {
case GGML_OP_CPY:
case GGML_OP_DUP:
{
node->n_tasks = n_threads;

size_t cur = 0;
if (ggml_is_quantized(node->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_ADD:
case GGML_OP_ADD1:
{
node->n_tasks = n_threads;

size_t cur = 0;

if (ggml_is_quantized(node->src0->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_ACC:
{
node->n_tasks = n_threads;

size_t cur = 0;

if (ggml_is_quantized(node->src0->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_threads;
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_SUB:
case GGML_OP_DIV:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_LOG:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_ABS:
case GGML_OP_SGN:
case GGML_OP_NEG:
case GGML_OP_STEP:
case GGML_OP_TANH:
case GGML_OP_ELU:
case GGML_OP_RELU:
{
node->n_tasks = 1;
} break;
case GGML_OP_MUL:
case GGML_OP_GELU:
case GGML_OP_GELU_QUICK:
case GGML_OP_SILU:
case GGML_OP_SILU_BACK:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
{
node->n_tasks = n_threads;
} break;
case GGML_OP_MUL_MAT:
case GGML_OP_OUT_PROD:
{
node->n_tasks = n_threads;

// TODO: use different scheduling for different matrix sizes
//const int nr0 = ggml_nrows(node->src0);
//const int nr1 = ggml_nrows(node->src1);

//node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);

size_t cur = 0;
const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;

#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
}
else
#elif defined(GGML_USE_CLBLAST)
if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
}
else
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
if (node->src0->type != GGML_TYPE_F32) {
// here we need memory just for single 2D matrix from src0
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
}
} else
#endif
if (node->src1->type != vec_dot_type) {
cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
} else {
cur = 0;
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_SCALE:
{
node->n_tasks = 1;
} break;
case GGML_OP_SET:
case GGML_OP_CONT:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_GET_ROWS:
case GGML_OP_GET_ROWS_BACK:
case GGML_OP_DIAG:
case GGML_OP_DIAG_MASK_ZERO:
{
node->n_tasks = 1;
} break;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
node->n_tasks = n_threads;
} break;
case GGML_OP_ALIBI:
{
node->n_tasks = 1; //TODO
} break;
case GGML_OP_CLAMP:
{
node->n_tasks = 1; //TODO
} break;
case GGML_OP_CONV_1D:
{
node->n_tasks = n_threads;

GGML_ASSERT(node->src0->ne[3] == 1);
GGML_ASSERT(node->src1->ne[2] == 1);
GGML_ASSERT(node->src1->ne[3] == 1);

size_t cur = 0;
const int nk = node->src0->ne[0];

if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
);
} else if (node->src0->type == GGML_TYPE_F32 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*(
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
);
} else {
GGML_ASSERT(false);
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_CONV_2D:
{
node->n_tasks = n_threads;

GGML_ASSERT(node->src1->ne[3] == 1);

const int64_t ne00 = node->src0->ne[0]; // W
const int64_t ne01 = node->src0->ne[1]; // H
const int64_t ne02 = node->src0->ne[2]; // C
const int64_t ne03 = node->src0->ne[3]; // N

const int64_t ne10 = node->src1->ne[0]; // W
const int64_t ne11 = node->src1->ne[1]; // H
const int64_t ne12 = node->src1->ne[2]; // C

const int64_t nk = ne00*ne01;

UNUSED(ne02);
UNUSED(ne03);
UNUSED(nk);

size_t cur = 0;

if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
} else if (node->src0->type == GGML_TYPE_F32 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)* (ne10*ne11*ne12);
} else {
GGML_ASSERT(false);
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_ATTN:
{
node->n_tasks = n_threads;

size_t cur = 0;

const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);

if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
}

if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_FF:
{
node->n_tasks = n_threads;

size_t cur = 0;

if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
}

if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
node->n_tasks = n_threads;

size_t cur = 0;

const int64_t D = node->src0->ne[0];
const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
}

if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1:
case GGML_OP_MAP_CUSTOM2:
case GGML_OP_MAP_CUSTOM3:
{
node->n_tasks = 1;
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
node->n_tasks = n_threads;

size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);

work_size = MAX(work_size, cur);
} break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{
node->n_tasks = n_threads;

size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;

work_size = MAX(work_size, cur);
} break;
case GGML_OP_NONE:
{
node->n_tasks = 1;
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);
} break;
}
}

if (cgraph->work != NULL && work_size > cgraph->work_size) {
GGML_ASSERT(false); // TODO: better handling
}

if (work_size > 0 && cgraph->work == NULL) {
cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);

GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
}
}

// create thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; ++j) {
@ -16473,6 +16501,17 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
}
}

void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);

struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
GGML_ASSERT(buf);

cplan.work_data = buf->data;

ggml_graph_compute(cgraph, &cplan);
}

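ggml_graph_compute_with_ctx() above restores the old single-call convenience, but the work buffer now comes out of the caller's ggml_context as a GGML_TYPE_I8 tensor, so the context has to be created with enough spare memory. A hedged sketch mirroring the usage example from ggml.h; the 64 MB mem_size is an assumed headroom value, not something this commit prescribes:

#include "ggml.h"

static float add_two_vectors(void) {
    // mem_size must cover the tensors, the graph, and the work tensor (assumed headroom)
    struct ggml_init_params params = { /*.mem_size =*/ 64*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(a, 3.0f);
    ggml_set_f32(b, 4.0f);

    struct ggml_cgraph gf = ggml_build_forward(ggml_add(ctx, a, b));

    // the work tensor is allocated inside ctx by the helper
    ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads*/ 1);

    const float result = ggml_get_f32_1d(gf.nodes[gf.n_nodes - 1], 0);
    ggml_free(ctx);
    return result;
}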
struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
for (int i = 0; i < cgraph->n_leafs; i++) {
struct ggml_tensor * leaf = cgraph->leafs[i];
@ -16511,14 +16550,13 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
const int64_t * ne = tensor->ne;
const size_t * nb = tensor->nb;

fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
arg,
ggml_type_name(tensor->type),
ggml_op_name (tensor->op),
tensor->n_dims,
ne[0], ne[1], ne[2], ne[3],
nb[0], nb[1], nb[2], nb[3],
tensor->n_tasks,
tensor->data,
tensor->name);
}
@ -17254,9 +17292,6 @@ static enum ggml_opt_result ggml_opt_adam(
struct ggml_cgraph * gb) {
GGML_ASSERT(ggml_is_scalar(f));

gf->n_threads = params.n_threads;
gb->n_threads = params.n_threads;

// these will store the parameters we want to optimize
struct ggml_tensor * ps[GGML_MAX_PARAMS];

@ -17303,7 +17338,8 @@ static enum ggml_opt_result ggml_opt_adam(
// compute the function value
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute(ctx, gb);

ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);

opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
opt->adam.fx_best = opt->adam.fx_prev;
@ -17383,7 +17419,8 @@ static enum ggml_opt_result ggml_opt_adam(

ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute(ctx, gb);

ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);

const float fx = ggml_get_f32_1d(f, 0);

@ -17505,7 +17542,8 @@ static enum ggml_opt_result linesearch_backtracking(

ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute(ctx, gb);

ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);

ggml_opt_get_grad(np, ps, g);

@ -17573,9 +17611,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
}
}

gf->n_threads = params.n_threads;
gb->n_threads = params.n_threads;

const int m = params.lbfgs.m;

// these will store the parameters we want to optimize
@ -17627,7 +17662,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(

ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute(ctx, gb);

ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);

ggml_opt_get_grad(np, ps, g);

36
ggml.h
@ -65,7 +65,7 @@
// ggml_set_f32(a, 3.0f);
// ggml_set_f32(b, 4.0f);
//
// ggml_graph_compute(ctx0, &gf);
// ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
//
// printf("f = %f\n", ggml_get_f32_1d(f, 0));
//
@ -418,9 +418,6 @@ extern "C" {
struct ggml_tensor * src1;
struct ggml_tensor * opt[GGML_MAX_OPT];

// thread scheduling
int n_tasks;

// performance
int perf_runs;
int64_t perf_cycles;
@ -432,19 +429,27 @@ extern "C" {

void * extra; // extra things e.g. for ggml-cuda.cu

char padding[4];
char padding[8];
};

static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);

// the compute plan that needs to be prepared for ggml_graph_compute()
// since https://github.com/ggerganov/ggml/issues/287
struct ggml_cplan {
size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()`
uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`

int n_threads;

// the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
int n_tasks[GGML_MAX_NODES];
};

// computation graph
struct ggml_cgraph {
int n_nodes;
int n_leafs;
int n_threads;

size_t work_size;
struct ggml_tensor * work;

struct ggml_tensor * nodes[GGML_MAX_NODES];
struct ggml_tensor * grads[GGML_MAX_NODES];
@ -1290,15 +1295,22 @@ extern "C" {

GGML_API void ggml_set_param(
struct ggml_context * ctx,
struct ggml_tensor * tensor);
struct ggml_tensor * tensor);

GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);

GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);

GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
// ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);

// same as ggml_graph_compute() but the work data is allocated as a part of the context
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);

GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);

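For downstream code, the migration implied by these declarations is mechanical: drop the n_threads field on the graph and pass the thread count to the plan (or to the context helper) instead. A hedged before/after sketch; `t`, `ctx` and `n_threads` are assumed to exist in the caller:

// before: threads lived on the graph, compute took the context
//   struct ggml_cgraph gf = ggml_build_forward(t);
//   gf.n_threads = n_threads;
//   ggml_graph_compute(ctx, &gf);

// after: the graph no longer carries a thread count
struct ggml_cgraph gf = ggml_build_forward(t);
ggml_graph_compute_with_ctx(ctx, &gf, n_threads);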
60
llama.cpp
@ -86,6 +86,25 @@ void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
(void) tensor;
}

//
// ggml helpers
//

static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);

if (plan.work_size > 0) {
buf.resize(plan.work_size);
plan.work_data = buf.data();
}

ggml_graph_compute(graph, &plan);
}

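The helper above re-plans the graph on every call and keeps one growable std::vector per llama_context, so repeated evaluations stop allocating once the buffer reaches its high-water mark. The same pattern for a plain C caller might look like the following sketch; the work_buf struct and the function name are illustrative, not from this commit:

#include <stdlib.h>
#include "ggml.h"

struct work_buf { uint8_t * data; size_t size; };

// sketch: grow-only scratch buffer reused across ggml_graph_compute() calls
static void graph_compute_reuse(struct work_buf * buf, struct ggml_cgraph * graph, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);

    if (plan.work_size > buf->size) {
        buf->data = (uint8_t *) realloc(buf->data, plan.work_size); // keep the larger buffer for later calls
        buf->size = plan.work_size;
    }
    plan.work_data = buf->data;

    ggml_graph_compute(graph, &plan);
}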
//
// memory sizes
//

static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
{
static std::map<e_model, size_t> k_sizes = {
@ -328,6 +347,9 @@ struct llama_context {
// input embedding (1-dimensional array: [n_embd])
std::vector<float> embedding;

// reusable buffer for `struct ggml_graph_plan.work_data`
std::vector<uint8_t> work_buffer;

// memory buffers used to evaluate the model
// TODO: move in llama_state
llama_ctx_buffer buf_compute;
@ -768,7 +790,6 @@ struct llama_model_loader {

};

//
// kv cache
//
@ -1284,17 +1305,11 @@ static bool llama_eval_internal(
const float * embd,
const int n_tokens,
const int n_past,
const int n_threads,
int n_threads,
const char * cgraph_fname) {

LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));

// enforce that the first token is BOS
if (tokens && n_past == 0 && tokens[0] != llama_token_bos()) {
fprintf(stderr, "%s: first token must be BOS\n", __func__);
return false;
}

const int64_t t_start_us = ggml_time_us();

const int N = n_tokens;
@ -1325,10 +1340,11 @@ static bool llama_eval_internal(

struct ggml_context * ctx0 = ggml_init(params);

ggml_cgraph gf = {};

// for big prompts, if BLAS is enabled, it is better to use only one thread
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
ggml_cgraph gf = {};
gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;

struct ggml_tensor * cur;
struct ggml_tensor * inpL;
@ -1642,6 +1658,7 @@ static bool llama_eval_internal(

#ifdef GGML_USE_METAL
if (lctx.ctx_metal && N == 1) {
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
ggml_metal_graph_compute(lctx.ctx_metal, &gf);
ggml_metal_get_tensor (lctx.ctx_metal, cur);
} else {
@ -1661,10 +1678,10 @@ static bool llama_eval_internal(
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
}

ggml_graph_compute(ctx0, &gf);
ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
}
#else
ggml_graph_compute(ctx0, &gf);
ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
#endif

if (cgraph_fname) {
@ -2626,8 +2643,8 @@ void llama_free_model(struct llama_model * model) {
}

struct llama_context * llama_new_context_with_model(
struct llama_model * model,
struct llama_context_params params) {
struct llama_model * model,
struct llama_context_params params) {

if (!model) {
return nullptr;
@ -2704,7 +2721,7 @@ struct llama_context * llama_new_context_with_model(
#ifdef GGML_USE_METAL
if (params.n_gpu_layers > 0) {
// this allocates all Metal resources and memory buffers
ctx->ctx_metal = ggml_metal_init();
ctx->ctx_metal = ggml_metal_init(1);

void * data_ptr = NULL;
size_t data_size = 0;
@ -2871,6 +2888,9 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
// read tensors and apply
bool warned = false;
int n_tensors = 0;

std::vector<uint8_t> work_buffer;

while (true) {
int32_t n_dims;
int32_t length;
@ -3035,8 +3055,8 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
}

struct ggml_cgraph gf = ggml_build_forward(r);
gf.n_threads = n_threads;
ggml_graph_compute(lora_ctx, &gf);

ggml_graph_compute_helper(work_buffer, &gf, n_threads);

// we won't need these tensors again, reset the context to save memory
ggml_free(lora_ctx);
@ -3189,7 +3209,6 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {

ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
ggml_cgraph gf{};
gf.n_threads = 1;

ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
kout3d->data = out;
@ -3209,7 +3228,7 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {

ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
ggml_graph_compute(cpy_ctx, &gf);
ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);

ggml_free(cpy_ctx);
}
@ -3295,7 +3314,6 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {

ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
ggml_cgraph gf{};
gf.n_threads = 1;

ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
kin3d->data = (void *) inp;
@ -3315,7 +3333,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {

ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
ggml_graph_compute(cpy_ctx, &gf);
ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);

ggml_free(cpy_ctx);
}
@ -10,5 +10,5 @@ llama_add_test(test-quantize-fns.cpp)
llama_add_test(test-quantize-perf.cpp)
llama_add_test(test-sampling.cpp)
llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
# llama_add_test(test-grad0.c) # SLOW
llama_add_test(test-grad0.c) # SLOW
# llama_add_test(test-opt.c) # SLOW
@ -10,6 +10,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

#pragma GCC diagnostic ignored "-Wdouble-promotion"

#define MAX_NARGS 3

#undef MIN
@ -49,7 +51,7 @@ float frand(void) {

int irand(int n) {
if (n == 0) return 0;
else return rand()%n;
return rand()%n;
}

void get_random_dims(int64_t * dims, int ndims) {
@ -159,12 +161,14 @@ struct ggml_tensor * get_random_tensor_int(
float get_element(const struct ggml_tensor * t, int idx) {
if (t->type == GGML_TYPE_F32) {
return ((float *)t->data)[idx];
} else if (t->type == GGML_TYPE_I32) {
return ((int32_t *)t->data)[idx];
} else {
assert(false);
return INFINITY;
}

if (t->type == GGML_TYPE_I32) {
return ((int32_t *)t->data)[idx];
}

assert(false);
return INFINITY;
}

void set_element(struct ggml_tensor * t, int idx, float value) {
@ -215,15 +219,14 @@ bool check_gradient(
}

struct ggml_cgraph gf = ggml_build_forward (f);
gf.n_threads = n_threads;

struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
gb.n_threads = n_threads;

ggml_graph_compute(ctx0, &gf);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute(ctx0, &gb);

ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);

// ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
// ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot");
@ -236,15 +239,16 @@ bool check_gradient(
const float xm = x0 - eps;
const float xp = x0 + eps;
set_element(x[i], k, xp);
ggml_graph_compute(ctx0, &gf);

ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

const float f0 = ggml_get_f32_1d(f, 0);

set_element(x[i], k, xm);
ggml_graph_compute(ctx0, &gf);

ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

const float f1 = ggml_get_f32_1d(f, 0);

const float g0 = (f0 - f1)/(2.0f*eps);

set_element(x[i], k, x0);
@ -252,12 +256,13 @@ bool check_gradient(
// compute gradient using backward graph
ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute(ctx0, &gb);

ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);

const float g1 = get_element(x[i]->grad, k);

const float error_abs = fabsf(g0 - g1);
const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabs(g0) : 0;
const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;

if (error_abs > max_error_abs || error_rel > max_error_rel) {
printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
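For context on the re-enabled gradient test: check_gradient compares the analytic gradient g1 from the backward graph against a central finite difference of the forward graph. In the notation of the code above:

g0 = (f0 - f1) / (2*eps), with f0 = f(x0 + eps) and f1 = f(x0 - eps)
error_abs = |g0 - g1|, error_rel = |g0 - g1| / |g0| (taken as 0 when g0 == 0)

An element fails the check when either error exceeds max_error_abs or max_error_rel.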
@ -7,6 +7,7 @@

#define MAX_NARGS 2

#pragma GCC diagnostic ignored "-Wdouble-promotion"

//
// logging
@ -33,7 +34,7 @@
#define GGML_PRINT(...) printf(__VA_ARGS__)

float frand() {
float frand(void) {
return (float)rand()/(float)RAND_MAX;
}

@ -114,7 +115,7 @@ void set_element(struct ggml_tensor * t, int idx, float value) {
((float *)t->data)[idx] = value;
}

int main(int argc, const char ** argv) {
int main(void) {
struct ggml_init_params params = {
.mem_size = 1024*1024*1024,
.mem_buffer = NULL,
@ -137,10 +138,11 @@ int main(int argc, const char ** argv) {
struct ggml_tensor * d = ggml_sub(ctx, c, ab);
struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx, d));

struct ggml_cgraph ge = ggml_build_forward(e);
ggml_graph_reset (&ge);
ggml_graph_compute(ctx, &ge);
ggml_graph_reset(&ge);

ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);

const float fe = ggml_get_f32_1d(e, 0);
printf("%s: e = %.4f\n", __func__, fe);

@ -148,8 +150,10 @@ int main(int argc, const char ** argv) {

ggml_opt(ctx, opt_params, e);

ggml_graph_reset (&ge);
ggml_graph_compute(ctx, &ge);
ggml_graph_reset(&ge);

ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);

const float fe_opt = ggml_get_f32_1d(e, 0);
printf("%s: original e = %.4f\n", __func__, fe);
printf("%s: optimized e = %.4f\n", __func__, fe_opt);