From 83f28d8cdb2d20f5b31ce3f7144ca39bc9ac4861 Mon Sep 17 00:00:00 2001 From: kyri-petrou <67301607+kyri-petrou@users.noreply.github.com> Date: Wed, 10 Jul 2024 18:38:29 +0300 Subject: [PATCH] Make `ZQuery#run` reentrant safe (#499) --- .../src/main/scala/zio/query/ZQuery.scala | 24 +++++++++++++----- .../src/test/scala/zio/query/ZQuerySpec.scala | 25 ++++++++++++++++++- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/zio-query/shared/src/main/scala/zio/query/ZQuery.scala b/zio-query/shared/src/main/scala/zio/query/ZQuery.scala index 90af1be5..28972d57 100644 --- a/zio-query/shared/src/main/scala/zio/query/ZQuery.scala +++ b/zio-query/shared/src/main/scala/zio/query/ZQuery.scala @@ -539,23 +539,35 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result * Returns an effect that models executing this query with the specified * cache. */ - def runCache(cache: => Cache)(implicit trace: Trace): ZIO[R, E, A] = + def runCache(cache: => Cache)(implicit trace: Trace): ZIO[R, E, A] = { + import ZQuery.{currentCache, currentScope} + + def setRef[V](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], newValue: V): V = { + val oldValue = state.getFiberRefOrNull(fiberRef) + state.setFiberRef(fiberRef, newValue) + oldValue + } + + def resetRef[V <: AnyRef](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], oldValue: V): Unit = + if (oldValue ne null) state.setFiberRef(fiberRef, oldValue) else state.deleteFiberRef(fiberRef) + asExitOrElse(null) match { case null => ZIO.uninterruptibleMask { restore => ZIO.withFiberRuntime[R, E, A] { (state, _) => - val scope = QueryScope.make() - state.setFiberRef(ZQuery.currentCache, Some(cache)) - state.setFiberRef(ZQuery.currentScope, scope) + val scope = QueryScope.make() + val oldCache = setRef(state, currentCache, Some(cache)) + val oldScope = setRef(state, currentScope, scope) restore(runToZIO).exitWith { exit => - state.deleteFiberRef(ZQuery.currentCache) - state.deleteFiberRef(ZQuery.currentScope) + resetRef(state, currentCache, oldCache) + resetRef(state, currentScope, oldScope) scope.closeAndExitWith(exit) } } } case exit => exit } + } /** * Returns an effect that models executing this query, returning the query diff --git a/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala b/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala index 391d018d..44824c32 100644 --- a/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala +++ b/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala @@ -2,6 +2,7 @@ package zio.query import zio._ import zio.query.QueryAspect._ +import zio.query.internal.QueryScope import zio.test.Assertion._ import zio.test.TestAspect.{after, nonFlaky, silent} import zio.test.{TestClock, TestConsole, TestEnvironment, _} @@ -270,7 +271,7 @@ object ZQuerySpec extends ZIOBaseSpec { assert(log)(hasAt(0)(containsString("GetNameById(1)"))) && assert(log)(hasAt(0)(containsString("GetNameById(2)"))) && assert(log)(hasAt(1)(containsString("GetNameById(1)"))) - } @@ nonFlaky, + } @@ nonFlaky(10), suite("race")( test("race with never") { val query = ZQuery.never.race(ZQuery.succeed(())) @@ -370,6 +371,28 @@ object ZQuerySpec extends ZIOBaseSpec { value <- ref.get } yield assertTrue(value == 1, results.forall(_.isLeft)) } + ), + suite("run")( + test("cache is reentrant safe") { + val q = + for { + c1 <- ZQuery.fromZIO(ZQuery.currentCache.get) + _ <- ZQuery.fromZIO(ZQuery.succeed("foo").run) + c2 <- ZQuery.fromZIO(ZQuery.currentCache.get) + } yield (c1, c2) + + q.run.map { case (c1, c2) => assertTrue(c1.isDefined, c1 == c2) } + }, + test("scope is reentrant safe") { + val q = + for { + c1 <- ZQuery.fromZIO(ZQuery.currentScope.get) + _ <- ZQuery.fromZIO(ZQuery.succeed("foo").run) + c2 <- ZQuery.fromZIO(ZQuery.currentScope.get) + } yield (c1, c2) + + q.run.map { case (c1, c2) => assertTrue(c1 != QueryScope.NoOp, c1 == c2) } + } ) ) @@ silent