mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-28 21:07:06 +01:00
tests : more
This commit is contained in:
parent
abeaf0d90e
commit
c6c1132e5e
@ -137,7 +137,10 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||
@ -505,7 +508,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||
@ -2166,7 +2172,10 @@ static bool ggml_metal_graph_compute(
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
|
@ -2326,7 +2326,10 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
|
||||
|
||||
kernel void kernel_cpy_f16_f16(
|
||||
device const half * src0,
|
||||
|
5
ggml.c
5
ggml.c
@ -13554,11 +13554,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
|
||||
const int64_t D = neq0;
|
||||
const int64_t N = neq1;
|
||||
const int64_t P = nek1 - N;
|
||||
|
||||
GGML_ASSERT(ne0 == D);
|
||||
GGML_ASSERT(ne2 == N);
|
||||
GGML_ASSERT(P >= 0);
|
||||
|
||||
GGML_ASSERT(nbq0 == sizeof(float));
|
||||
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
|
||||
@ -13569,7 +13567,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
GGML_ASSERT(nev0 == D);
|
||||
|
||||
GGML_ASSERT(neq1 == N);
|
||||
GGML_ASSERT(nek1 == N + P);
|
||||
GGML_ASSERT(nev0 == D);
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
@ -13608,8 +13605,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
float scale = 1.0f;
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
|
||||
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
||||
|
||||
// loop over n_batch and n_head
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
|
@ -1726,25 +1726,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
test_cases.emplace_back(new test_pad());
|
||||
test_cases.emplace_back(new test_leaky_relu());
|
||||
|
||||
test_cases.emplace_back(new test_attn(64, 32, 512, 8));
|
||||
test_cases.emplace_back(new test_attn(64, 32, 512, 7));
|
||||
test_cases.emplace_back(new test_attn(64, 32, 512, 1));
|
||||
test_cases.emplace_back(new test_attn(80, 32, 512, 8));
|
||||
test_cases.emplace_back(new test_attn(80, 32, 512, 7));
|
||||
test_cases.emplace_back(new test_attn(80, 32, 512, 1));
|
||||
test_cases.emplace_back(new test_attn(128, 32, 512, 8));
|
||||
test_cases.emplace_back(new test_attn(128, 32, 512, 7));
|
||||
test_cases.emplace_back(new test_attn(128, 32, 512, 1));
|
||||
|
||||
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8));
|
||||
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7));
|
||||
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1));
|
||||
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8));
|
||||
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7));
|
||||
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1));
|
||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8));
|
||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7));
|
||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1));
|
||||
for (int hs : { 64, 80, 96, 112, 128, 256, }) {
|
||||
for (int nh : { 32, }) {
|
||||
for (int kv : { 512, 1024, 2048, 4096, }) {
|
||||
for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
|
||||
test_cases.emplace_back(new test_attn (hs, nh, kv, nb));
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(__SANITIZE_THREAD__)
|
||||
// FIXME: these tests use too much memory with thread sanitizer
|
||||
|
Loading…
Reference in New Issue
Block a user