mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 14:20:31 +01:00
9c67c2773d
* ggml : add ggml_flash_attn_ext API * ggml : fix GQA support in ggml_flash_attn_ext * ggml : online attention (CPU) * metal : initial implementation * metal : f16 precision * metal : reduce branches * metal : specialize for head size * wip : 8 rows per simd group * wip : 4 rows per simd group * wip : template for rows per warp * metal : parallelize across KV size * metal : parallel reduce across heads * metal : efficient flash_attn_f16 implementation * metal : avoid redundant loads of the attention * metal : scale and mask in matrix form * metal : fix comment * llama : avoid ggml_cast, use F32 query * metal : add parallel reduce version (disabled) * metal : move output into local memory + optimize - the result from each simdgroup now stays in the registers - significantly reduced SRAM usage - more efficient skipping of -INF blocks - avoid simdgroup barrier in hot loop - add comments * metal : add tests, fix scaling, support C > 32 * metal : improve precision * ggml : fix f16 mad * metal : minor * metal : support Q > 8 * tests : add ATTN tests * metal : disable buffer allocation logs * tests : more * metal : faster inner loop for C == 32 * metal : fix array initialization * tests : ifdef * ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext * ggml : fix ggml_soft_max mask requirement * cuda : fix soft_max to use correct mask size * cuda : add flash_attn kernel (wip) * metal : optimize softmax for C > 32 * metal : optimize softmax * tests : minor fix * cuda : avoid zeroing fragments * tests : update dims * cuda : fix __hisinf() result check * cuda : avoid warp_reduce for smax * cuda : use int instead of int64_t Noticeably improves performance (thanks to Johannes) * cuda : make loops use the same loop values Thanks Johannes again for the tip * cuda : unroll some of the loops * cuda : avoid __hisinf branches * cuda : use half2 in softmax * cuda : switch to 1 warp for bs > 16 * cuda : speed-up reduce part of the kernel * cuda : unroll Q*K^T loop * cuda : fix -INF block check * cuda : simplify softmax * cuda : fix matrix names * cuda : minor * llama : adapt to F16 KQ_pos * llama : adapt new models to F16 KQ_mask * ggml : fix F16 store (ARM NEON) * llama : fix type of KQ_mask and KQ_pos * ggml : fix CPU soft_max * tests : add hs=256 * cuda : fix build * metal : improve perf via smaller int registers * cuda : adapt soft_max to F16 mask and pos * CUDA: faster FlashAttention, kernel for bs == 1 * 16 cols for Phi-2 * no vec for hs, no hs==256 ncols==32 for Volta * adjust kernel selection logic * 4 warps, 256 stride for all D * no ncols == 64 * Multiple parallel blocks for batch size 1 * fix compile warnings * fix excessive KQ_b loads * fix cmake build * fix KV cache padding, NaN from INFINITY (#6438) * llama : flash_attn cparam + fix defrag * server: support flash_attn param * server: bench: enable flash_attn param * CUDA: refactor host code, dyn. par. blocks * fix flash_attn_vec_f16 race condition * flush softmax exp below threshold to 0 * store temp KQ in registers * Calculate KQ as FP32 if KQV has GGML_PREC_F32 * Add __hgt2_mask implementation for CUDA 11 * fix KQ FP32 precision fpr parallel_blocks > 1 * llama-bench : add -fa,--flash-attn arg * metal : add BS=1 kernel for flash attention (#6508) * metal : add BS=1 kernel for flash attention (wip) * metal : support more than 1 warps * metal : opts * metal : opt * metal : switch to parallel reduce * metal : reduce registers * metal : simplify * metal : initial FA vec kernel * metal : use F32 attention accumulators * batched-bench : add fattn arg * llama : simplify llama_build_kv_store ggml-ci * llama : adapt build_olmo to changes * ggml : fix arm fp16 store on windows * metal : clean-up * metal : clean-up kernel code * metal : minor * tests : remove benchmarks ggml-ci * ggml : fix avx512 const correctness ggml-ci * ggml : fix soft_max with bias on CPU ggml-ci * common : print --flash-attn in help * ggml : fix num dimensions in ggml_flash_attn_ext * llama : force disable flash attention for incompatible models * ggml : ggml_soft_max support F16/F32 mask/pos ggml-ci * cuda : uint -> uint32_t * cuda : "constexpr dim3" -> "const dim3" ggml-ci * cuda : try to fix __hgt2_mask ggml-ci * ggml : add TODO's for F16/F32 mask/pos support in other backends * llama : replace bool need_kq_pos with use_alibi * llama : prep ALiBi support for BERT models ggml-ci * llama : fix n_batch requirements ggml-ci * cont * server : add help for --flash-attn arg * llama : disable FA for AMD * tests : remove TMP_ATTN_BENCH ggml-ci * llama : support save/load state with FA enabled ggml-ci * ci : add CUDA save-load-state tests ggml-ci * llama : llama_kv_cache_clear zeroes data + fix save-load seq ggml-ci * llama : fix copy-paste errors, add TODO * llama : disallow incompatible states * llama : update llama_state_get_size after v_trans field * metal : remove tmp log * llama : add static reminder for llama_state_get_size * metal : fix max nsg ggml-ci * ci : fix arg order ggml-ci --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de> Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com> |
||
---|---|---|
.. | ||
batched-bench.cpp | ||
CMakeLists.txt | ||
README.md |
llama.cpp/example/batched-bench
Benchmark the batched decoding performance of llama.cpp
Usage
There are 2 modes of operation:
prompt not shared
- each batch has a separate prompt of sizePP
(i.e.N_KV = B*(PP + TG)
)prompt is shared
- there is a common prompt of sizePP
used by all batches (i.e.N_KV = PP + B*TG
)
./batched-bench MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] [MMQ] <PP> <TG> <PL>
# LLaMA 7B, F16, N_KV_MAX = 16384 (8GB), prompt not shared
./batched-bench ./models/llama-7b/ggml-model-f16.gguf 16384 2048 512 0 99
# LLaMA 7B, Q8_0, N_KV_MAX = 16384 (8GB), prompt is shared
./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 16384 2048 512 1 99
# custom set of batches
./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 2048 512 512 0 999 0 128,256,512 128,256 1,2,4,8,16,32
Sample results
PP
- prompt tokens per batchTG
- generated tokens per batchB
- number of batchesN_KV
- required KV cache sizeT_PP
- prompt processing time (i.e. time to first token)S_PP
- prompt processing speed ((B*PP)/T_PP
orPP/T_PP
)T_TG
- time to generate all batchesS_TG
- text generation speed ((B*TG)/T_TG
)T
- total timeS
- total speed (i.e. all tokens / total time)
PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
---|---|---|---|---|---|---|---|---|---|
128 | 128 | 1 | 256 | 0.108 | 1186.64 | 3.079 | 41.57 | 3.187 | 80.32 |
128 | 128 | 2 | 512 | 0.198 | 1295.19 | 5.029 | 50.90 | 5.227 | 97.95 |
128 | 128 | 4 | 1024 | 0.373 | 1373.96 | 6.878 | 74.44 | 7.251 | 141.23 |
128 | 128 | 8 | 2048 | 0.751 | 1363.27 | 7.344 | 139.43 | 8.095 | 252.99 |
128 | 128 | 16 | 4096 | 1.570 | 1304.68 | 8.455 | 242.23 | 10.024 | 408.60 |
128 | 128 | 32 | 8192 | 3.408 | 1201.73 | 8.801 | 465.40 | 12.209 | 670.96 |
128 | 256 | 1 | 384 | 0.107 | 1196.70 | 6.329 | 40.45 | 6.436 | 59.67 |
128 | 256 | 2 | 768 | 0.194 | 1317.45 | 10.239 | 50.00 | 10.433 | 73.61 |
128 | 256 | 4 | 1536 | 0.366 | 1399.03 | 13.960 | 73.35 | 14.326 | 107.22 |
128 | 256 | 8 | 3072 | 0.751 | 1363.92 | 15.110 | 135.54 | 15.861 | 193.69 |
128 | 256 | 16 | 6144 | 1.569 | 1304.93 | 18.073 | 226.64 | 19.642 | 312.80 |
128 | 256 | 32 | 12288 | 3.409 | 1201.35 | 19.223 | 426.15 | 22.633 | 542.93 |