mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 23:00:46 +01:00
tests : more
This commit is contained in:
parent
abeaf0d90e
commit
c6c1132e5e
@ -137,7 +137,10 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
|
||||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
||||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||||
@ -505,7 +508,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||||
@ -2166,7 +2172,10 @@ static bool ggml_metal_graph_compute(
|
|||||||
switch (ne00) {
|
switch (ne00) {
|
||||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||||
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||||
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
||||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
||||||
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
@ -2326,7 +2326,10 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>;
|
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>;
|
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
|
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
|
||||||
|
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
|
||||||
|
|
||||||
kernel void kernel_cpy_f16_f16(
|
kernel void kernel_cpy_f16_f16(
|
||||||
device const half * src0,
|
device const half * src0,
|
||||||
|
5
ggml.c
5
ggml.c
@ -13554,11 +13554,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
|
|
||||||
const int64_t D = neq0;
|
const int64_t D = neq0;
|
||||||
const int64_t N = neq1;
|
const int64_t N = neq1;
|
||||||
const int64_t P = nek1 - N;
|
|
||||||
|
|
||||||
GGML_ASSERT(ne0 == D);
|
GGML_ASSERT(ne0 == D);
|
||||||
GGML_ASSERT(ne2 == N);
|
GGML_ASSERT(ne2 == N);
|
||||||
GGML_ASSERT(P >= 0);
|
|
||||||
|
|
||||||
GGML_ASSERT(nbq0 == sizeof(float));
|
GGML_ASSERT(nbq0 == sizeof(float));
|
||||||
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
|
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
|
||||||
@ -13569,7 +13567,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
GGML_ASSERT(nev0 == D);
|
GGML_ASSERT(nev0 == D);
|
||||||
|
|
||||||
GGML_ASSERT(neq1 == N);
|
GGML_ASSERT(neq1 == N);
|
||||||
GGML_ASSERT(nek1 == N + P);
|
|
||||||
GGML_ASSERT(nev0 == D);
|
GGML_ASSERT(nev0 == D);
|
||||||
|
|
||||||
// dst cannot be transposed or permuted
|
// dst cannot be transposed or permuted
|
||||||
@ -13608,8 +13605,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
float scale = 1.0f;
|
float scale = 1.0f;
|
||||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||||
|
|
||||||
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
|
||||||
|
|
||||||
// loop over n_batch and n_head
|
// loop over n_batch and n_head
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
for (int ir = ir0; ir < ir1; ++ir) {
|
||||||
// q indices
|
// q indices
|
||||||
|
@ -1726,25 +1726,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
test_cases.emplace_back(new test_pad());
|
test_cases.emplace_back(new test_pad());
|
||||||
test_cases.emplace_back(new test_leaky_relu());
|
test_cases.emplace_back(new test_leaky_relu());
|
||||||
|
|
||||||
test_cases.emplace_back(new test_attn(64, 32, 512, 8));
|
for (int hs : { 64, 80, 96, 112, 128, 256, }) {
|
||||||
test_cases.emplace_back(new test_attn(64, 32, 512, 7));
|
for (int nh : { 32, }) {
|
||||||
test_cases.emplace_back(new test_attn(64, 32, 512, 1));
|
for (int kv : { 512, 1024, 2048, 4096, }) {
|
||||||
test_cases.emplace_back(new test_attn(80, 32, 512, 8));
|
for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
|
||||||
test_cases.emplace_back(new test_attn(80, 32, 512, 7));
|
test_cases.emplace_back(new test_attn (hs, nh, kv, nb));
|
||||||
test_cases.emplace_back(new test_attn(80, 32, 512, 1));
|
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
|
||||||
test_cases.emplace_back(new test_attn(128, 32, 512, 8));
|
}
|
||||||
test_cases.emplace_back(new test_attn(128, 32, 512, 7));
|
}
|
||||||
test_cases.emplace_back(new test_attn(128, 32, 512, 1));
|
}
|
||||||
|
}
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8));
|
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7));
|
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1));
|
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8));
|
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7));
|
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1));
|
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8));
|
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7));
|
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1));
|
|
||||||
|
|
||||||
#if !defined(__SANITIZE_THREAD__)
|
#if !defined(__SANITIZE_THREAD__)
|
||||||
// FIXME: these tests use too much memory with thread sanitizer
|
// FIXME: these tests use too much memory with thread sanitizer
|
||||||
|
Loading…
Reference in New Issue
Block a user