vulkan: fix group_norm (#10496)

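Fix the calculation of the end of each workgroup's range in the group_norm
shader: `end` was computed as `start + group_size`, but `start` already
includes the per-thread offset `tid`, so the end overshot the group boundary.
Add a backend test that covers the bad case (taken from stable diffusion).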

Fixes https://github.com/leejet/stable-diffusion.cpp/issues/439.
This commit is contained in:
Jeff Bolz 2024-11-26 09:45:05 -06:00 committed by GitHub
parent 45abe0f74e
commit 904109ed0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 4 additions and 3 deletions

View File

@ -7157,7 +7157,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
const int32_t max_period = tensor->op_params[1]; const int32_t max_period = tensor->op_params[1];
tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
} else if (tensor->op == GGML_OP_POOL_2D) { } else if (tensor->op == GGML_OP_POOL_2D) {
enum ggml_op_pool op = static_cast<ggml_op_pool>(dst->op_params[0]); enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
const int32_t k0 = tensor->op_params[1]; const int32_t k0 = tensor->op_params[1];
const int32_t k1 = tensor->op_params[2]; const int32_t k1 = tensor->op_params[2];
const int32_t s0 = tensor->op_params[3]; const int32_t s0 = tensor->op_params[3];

View File

@ -19,7 +19,7 @@ void main() {
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
const uint start = gl_WorkGroupID.x * group_size + tid; const uint start = gl_WorkGroupID.x * group_size + tid;
const uint end = start + group_size; const uint end = (gl_WorkGroupID.x + 1) * group_size;
tmp[tid] = 0.0f; tmp[tid] = 0.0f;

View File

@ -3796,7 +3796,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_upscale()); test_cases.emplace_back(new test_upscale());
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
test_cases.emplace_back(new test_upscale_ext()); test_cases.emplace_back(new test_upscale_ext());
test_cases.emplace_back(new test_group_norm()); test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_acc()); test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_arange()); test_cases.emplace_back(new test_arange());