diff --git a/ggml.c b/ggml.c
index 5025ec23b..d8f74f3ce 100644
--- a/ggml.c
+++ b/ggml.c
@@ -60,6 +60,9 @@
 
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
+typedef atomic_int atomic_flag;
+
+#define ATOMIC_FLAG_INIT 0
 
 static void atomic_store(atomic_int * ptr, LONG val) {
     InterlockedExchange(ptr, val);
@@ -73,6 +76,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
 static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
     return atomic_fetch_add(ptr, -(dec));
 }
+static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
+    return InterlockedExchange(ptr, 1);
+}
+static void atomic_flag_clear(atomic_flag * ptr) {
+    InterlockedExchange(ptr, 0);
+}
 
 typedef HANDLE pthread_t;
 
@@ -2883,24 +2892,20 @@ struct ggml_state {
 
 // global state
 static struct ggml_state g_state;
-static atomic_int g_state_barrier = 0;
+static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
 
 // barrier via spin lock
 inline static void ggml_critical_section_start(void) {
-    int processing = atomic_fetch_add(&g_state_barrier, 1);
-
-    while (processing > 0) {
-        // wait for other threads to finish
-        atomic_fetch_sub(&g_state_barrier, 1);
-        sched_yield(); // TODO: reconsider this
-        processing = atomic_fetch_add(&g_state_barrier, 1);
+    while (atomic_flag_test_and_set(&g_state_critical)) {
+        // spin
+        sched_yield();
     }
 }
 
 // TODO: make this somehow automatically executed
 // some sort of "sentry" mechanism
 inline static void ggml_critical_section_end(void) {
-    atomic_fetch_sub(&g_state_barrier, 1);
+    atomic_flag_clear(&g_state_critical);
 }
 
 #if defined(__gnu_linux__)
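The patch replaces the counter-based barrier with a single test-and-set flag: `atomic_flag_test_and_set` atomically sets the flag and returns its previous value, so exactly one thread observes `0` and enters the critical section, while the old `atomic_fetch_add`/`atomic_fetch_sub` dance had every waiting thread repeatedly bumping and un-bumping the shared counter. Because `atomic_flag` is the one primitive C11 guarantees to be lock-free, only the Windows/MSVC path needs the `InterlockedExchange` shim added above. Below is a minimal standalone sketch of the same pattern against plain C11 `<stdatomic.h>`, not the patched code itself; the names `demo_lock`, `g_counter`, and `worker` are illustrative only.

```c
// Sketch of the test-and-set spin lock using standard C11 atomics.
// Build e.g. with: cc -std=c11 -pthread spinlock_demo.c
#include <stdatomic.h>
#include <stdio.h>
#include <pthread.h>
#include <sched.h>

static atomic_flag demo_lock = ATOMIC_FLAG_INIT;
static long g_counter = 0; // shared state guarded by demo_lock

static void demo_lock_acquire(void) {
    // atomic_flag_test_and_set returns the flag's previous value, so the
    // loop exits only for the one thread that flipped it from 0 to 1.
    while (atomic_flag_test_and_set(&demo_lock)) {
        sched_yield(); // back off instead of burning the core
    }
}

static void demo_lock_release(void) {
    atomic_flag_clear(&demo_lock);
}

static void * worker(void * arg) {
    (void) arg;
    for (int i = 0; i < 100000; i++) {
        demo_lock_acquire();
        g_counter++; // critical section
        demo_lock_release();
    }
    return NULL;
}

int main(void) {
    pthread_t threads[4];
    for (int i = 0; i < 4; i++) {
        pthread_create(&threads[i], NULL, worker, NULL);
    }
    for (int i = 0; i < 4; i++) {
        pthread_join(threads[i], NULL);
    }
    printf("g_counter = %ld\n", g_counter); // always 400000 with the lock
    return 0;
}
```

As in the patched `ggml_critical_section_start`, the `sched_yield` call keeps a waiting thread from monopolizing its core while another thread holds the lock.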