From 5f17a0be3bea8dd1af259939ce4661d893dca3d7 Mon Sep 17 00:00:00 2001 From: Simon Vergauwen Date: Mon, 5 Dec 2022 15:46:01 +0100 Subject: [PATCH] Add CyclicBarrier (#2857) * Remove @Throws CountDownLatch --- .../api/arrow-fx-coroutines.api | 6 ++ .../arrow/fx/coroutines/CountDownLatch.kt | 2 +- .../arrow/fx/coroutines/CyclicBarrier.kt | 52 ++++++++++++ .../arrow/fx/coroutines/CyclicBarrierSpec.kt | 84 +++++++++++++++++++ 4 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CyclicBarrier.kt create mode 100644 arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/CyclicBarrierSpec.kt diff --git a/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api b/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api index ee29ae2aac7..874a45f7b4d 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api +++ b/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api @@ -104,6 +104,12 @@ public final class arrow/fx/coroutines/CountDownLatch { public final fun countDown ()V } +public final class arrow/fx/coroutines/CyclicBarrier { + public fun (I)V + public final fun await (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getCapacity ()I +} + public abstract class arrow/fx/coroutines/ExitCase { public static final field Companion Larrow/fx/coroutines/ExitCase$Companion; } diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CountDownLatch.kt b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CountDownLatch.kt index 119c095d17f..ba7c2b2a28c 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CountDownLatch.kt +++ b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CountDownLatch.kt @@ -11,7 +11,7 @@ import kotlinx.coroutines.CompletableDeferred * Must be initialised with an [initial] value of 1 or higher, * if constructed with 0 or negative value then it throws [IllegalArgumentException]. */ -public class CountDownLatch @Throws(IllegalArgumentException::class) constructor(private val initial: Long) { +public class CountDownLatch(private val initial: Long) { private val signal = CompletableDeferred() private val count = AtomicRef(initial) diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CyclicBarrier.kt b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CyclicBarrier.kt new file mode 100644 index 00000000000..f1d81ab18ee --- /dev/null +++ b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CyclicBarrier.kt @@ -0,0 +1,52 @@ +package arrow.fx.coroutines + +import arrow.core.continuations.AtomicRef +import arrow.core.continuations.loop +import arrow.core.continuations.update +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred + +/** + * A [CyclicBarrier] is a synchronization mechanism that allows a set of coroutines to wait for each other + * to reach a certain point before continuing execution. + * It is called a "cyclic" barrier because it can be reused after all coroutines have reached the barrier and released. + * + * To use a CyclicBarrier, each coroutine must call the [await] method on the barrier object, + * which will cause the coroutine to suspend until the required number of coroutines have reached the barrier. + * Once all coroutines have reached the barrier they will _resume_ execution. + * + * Models the behavior of java.util.concurrent.CyclicBarrier in Kotlin with `suspend`. + */ +public class CyclicBarrier(public val capacity: Int) { + init { + require(capacity > 0) { + "Cyclic barrier must be constructed with positive non-zero capacity $capacity but was $capacity > 0" + } + } + + private data class State(val awaiting: Int, val epoch: Long, val unblock: CompletableDeferred) + + private val state: AtomicRef = AtomicRef(State(capacity, 0, CompletableDeferred())) + + /** + * When [await] is called the function will suspend until the required number of coroutines have reached the barrier. + * Once the [capacity] of the barrier has been reached, the coroutine will be released and continue execution. + */ + public suspend fun await() { + state.loop { original -> + val (awaiting, epoch, unblock) = original + val awaitingNow = awaiting - 1 + if (awaitingNow == 0 && state.compareAndSet(original, State(capacity, epoch + 1, CompletableDeferred()))) { + unblock.complete(Unit) + return + } else if (state.compareAndSet(original, State(awaitingNow, epoch, unblock))) { + return try { + unblock.await() + } catch (cancelled: CancellationException) { + state.update { s -> if (s.epoch == epoch) s.copy(awaiting = s.awaiting + 1) else s } + throw cancelled + } + } + } + } +} diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/CyclicBarrierSpec.kt b/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/CyclicBarrierSpec.kt new file mode 100644 index 00000000000..ea723a38b17 --- /dev/null +++ b/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/CyclicBarrierSpec.kt @@ -0,0 +1,84 @@ +package arrow.fx.coroutines + +import arrow.core.Either +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.core.spec.style.StringSpec +import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.shouldBeTypeOf +import io.kotest.property.Arb +import io.kotest.property.arbitrary.constant +import io.kotest.property.arbitrary.int +import io.kotest.property.checkAll +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.launch + +class CyclicBarrierSpec : StringSpec({ + "should raise an exception when constructed with a negative or zero capacity" { + checkAll(Arb.int(Int.MIN_VALUE, 0)) { i -> + shouldThrow { CyclicBarrier(i) }.message shouldBe + "Cyclic barrier must be constructed with positive non-zero capacity $i but was $i > 0" + } + } + + "barrier of capacity 1 is a no op" { + checkAll(Arb.constant(Unit)) { + val barrier = CyclicBarrier(1) + barrier.await() + } + } + + "awaiting all in parallel resumes all coroutines" { + checkAll(Arb.int(1, 100)) { i -> + val barrier = CyclicBarrier(i) + (0 until i).parTraverse { barrier.await() } + } + } + + "should reset once full" { + checkAll(Arb.constant(Unit)) { + val barrier = CyclicBarrier(2) + parZip({ barrier.await() }, { barrier.await() }) { _, _ -> } + barrier.capacity shouldBe 2 + } + } + + "await is cancelable" { + checkAll(Arb.int(2, Int.MAX_VALUE)) { i -> + val barrier = CyclicBarrier(i) + val exitCase = CompletableDeferred() + + val job = + launch(start = CoroutineStart.UNDISPATCHED) { + guaranteeCase({ barrier.await() }, exitCase::complete) + } + + job.cancelAndJoin() + exitCase.isCompleted shouldBe true + exitCase.await().shouldBeTypeOf() + } + } + + "should clean up upon cancelation of await" { + checkAll(Arb.constant(Unit)) { + val barrier = CyclicBarrier(2) + launch(start = CoroutineStart.UNDISPATCHED) { barrier.await() }.cancelAndJoin() + + barrier.capacity shouldBe 2 + } + } + + "race fiber cancel and barrier full" { + checkAll(Arb.constant(Unit)) { + val barrier = CyclicBarrier(2) + val job = launch(start = CoroutineStart.UNDISPATCHED) { barrier.await() } + when (raceN({ barrier.await() }, { job.cancelAndJoin() })) { + // without the epoch check in CyclicBarrier, a late cancellation would increment the count + // after the barrier has already reset, causing this code to never terminate (test times out) + is Either.Left -> parZip({ barrier.await() }, { barrier.await() }) { _, _ -> } + is Either.Right -> Unit + } + } + } +})