
Commit

Use a single FiberRef for retrieving the Cache and identifying disabled caching (#495)

* Use a single FiberRef for retrieving the Cache and identifying disabled caching

* Fix bin-compat issues
kyri-petrou authored Jul 10, 2024
1 parent 7875186 commit a345a27
Showing 3 changed files with 62 additions and 55 deletions.
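
In short, the boolean `cachingEnabled` flag is gone: `currentCache` now holds an `Option[Cache]`, where `None` means caching is disabled, and a new private `disabledCache` FiberRef stashes the active cache so it can be restored later. Condensed from the ZQuery.scala hunks below (`Cache` is zio-query's own type):

import zio.{FiberRef, Unsafe}

// `None` in `currentCache` now means "caching is disabled"; whatever cache was
// active at that point is stashed in `disabledCache` so that `cached` can restore it.
val currentCache: FiberRef[Option[Cache]] =
  FiberRef.unsafe.make(Option.empty[Cache])(Unsafe.unsafe)

private val disabledCache: FiberRef[Option[Cache]] =
  FiberRef.unsafe.make(Option.empty[Cache])(Unsafe.unsafe)

// Kept only so that previously compiled code keeps linking (the bin-compat fix).
@deprecated("No longer used, kept for binary compatibility only", "0.7.4")
val cachingEnabled: FiberRef[Boolean] =
  FiberRef.unsafe.make(true)(Unsafe.unsafe)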
2 changes: 2 additions & 0 deletions benchmarks/src/main/scala/zio/query/DataSourceBenchmark.scala
@@ -55,6 +55,7 @@ class DataSourceBenchmark {
@Benchmark
def fetchSumDuplicatedBenchmark(): Long = {
import FetchImpl._
import fetch.fetchM
type FIO[A] = Fetch[IO, A]

val reqs = (0 until count).toList.map(i => fetchPlusOne(1))
@@ -64,6 +65,7 @@ class DataSourceBenchmark {

@Benchmark
def fetchSumUniqueBenchmark(): Long = {
import fetch.fetchM
import FetchImpl._
type FIO[A] = Fetch[IO, A]

54 changes: 30 additions & 24 deletions zio-query/shared/src/main/scala/zio/query/ZQuery.scala
@@ -187,7 +187,13 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result
* [[memoize]] for memoizing the result of a single query
*/
def cached(implicit trace: Trace): ZQuery[R, E, A] =
ZQuery.acquireReleaseWith(ZQuery.cachingEnabled.getAndSet(true))(ZQuery.cachingEnabled.set)(_ => self)
ZQuery.unwrap(ZQuery.disabledCache.get.map {
case None => self
case s =>
val acq = ZQuery.disabledCache.set(None) *> ZQuery.currentCache.set(s)
val rel = ZQuery.disabledCache.set(s) *> ZQuery.currentCache.set(None)
ZQuery.acquireReleaseWith(acq)(_ => rel)(_ => self)
})

/**
* Recovers from all errors.
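
With this encoding, `cached` has a fast path: when caching is already enabled (`disabledCache` is `None`) it simply returns `self`, and it only swaps the stashed cache back in when it runs inside an `uncached` region (see the hunk further down). A small usage sketch, assuming the usual zio-query request API; the `GetUser` request and `Users` data source are made up for illustration:

import zio._
import zio.query._

// Hypothetical request and data source, purely for illustration.
final case class GetUser(id: Int) extends Request[Nothing, String]

val users: DataSource[Any, GetUser] =
  DataSource.fromFunction("Users")(req => s"user-${req.id}")

def user(id: Int): ZQuery[Any, Nothing, String] =
  ZQuery.fromRequest(GetUser(id))(users)

// The outer `uncached` stashes the active cache, so the first lookup goes straight
// to the data source; the inner `cached` restores the stashed cache, so lookups
// inside it are deduplicated/cached again.
val program: ZQuery[Any, Nothing, String] =
  (user(1) *> (user(1) *> user(1)).cached).uncached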
@@ -539,7 +545,7 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result
ZIO.uninterruptibleMask { restore =>
ZIO.withFiberRuntime[R, E, A] { (state, _) =>
val scope = QueryScope.make()
state.setFiberRef(ZQuery.currentCache, cache)
state.setFiberRef(ZQuery.currentCache, Some(cache))
state.setFiberRef(ZQuery.currentScope, scope)
restore(runToZIO).exitWith { exit =>
state.deleteFiberRef(ZQuery.currentCache)
@@ -680,7 +686,13 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result
* Disables caching for this query.
*/
def uncached(implicit trace: Trace): ZQuery[R, E, A] =
ZQuery.acquireReleaseWith(ZQuery.cachingEnabled.getAndSet(false))(ZQuery.cachingEnabled.set)(_ => self)
ZQuery.unwrap(ZQuery.currentCache.get.map {
case None => self
case s =>
val acq = ZQuery.disabledCache.set(s) *> ZQuery.currentCache.set(None)
val rel = ZQuery.disabledCache.set(None) *> ZQuery.currentCache.set(s)
ZQuery.acquireReleaseWith(acq)(_ => rel)(_ => self)
})

/**
* Converts a `ZQuery[R, Either[E, B], A]` into a `ZQuery[R, E, Either[A,
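
`uncached` is the mirror image: an active cache is moved into `disabledCache` and `currentCache` is cleared, with the release action swapping them back. A sketch of the observable effect, using a hypothetical counting data source and assuming the usual `*>` and `run` semantics:

import zio._
import zio.query._

final case class Key(n: Int) extends Request[Nothing, Int]

// A data source that counts how often it is actually hit.
def example: ZIO[Any, Nothing, Int] =
  Ref.make(0).flatMap { hits =>
    val ds: DataSource[Any, Key] =
      DataSource.fromFunctionZIO("Counting")(k => hits.update(_ + 1).as(k.n))
    val once = ZQuery.fromRequest(Key(1))(ds)
    // The first lookup populates the cache, the second should be served from it,
    // and the third runs in an uncached region, so it goes back to the data source.
    (once *> once *> once.uncached).run *> hits.get // expected to end up at 2
  }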
@@ -1451,13 +1463,9 @@ object ZQuery {
request0: => A
)(dataSource0: => DataSource[R, A])(implicit ev: A <:< Request[E, B], trace: Trace): ZQuery[R, E, B] =
ZQuery {
ZQuery.cachingEnabled.getWith { isCachingEnabled =>
val request = request0
val dataSource = dataSource0
if (isCachingEnabled)
ZQuery.currentCache.getWith(cachedResult(_, dataSource, request).toZIO)
else
uncachedResult(dataSource, request)
ZQuery.currentCache.getWith {
case Some(cache) => cachedResult(cache, dataSource0, request0).toZIO
case _ => uncachedResult(dataSource0, request0)
}
}

@@ -1512,12 +1520,9 @@
f: In => A
)(dataSource: DataSource[R, A])(implicit ev: A <:< Request[E, B], trace: Trace): ZQuery[R, E, Chunk[B]] =
ZQuery {
ZQuery.cachingEnabled.getWith {
if (_) {
ZQuery.currentCache.getWith(cache => CachedResult.foreach(as)(r => cachedResult(cache, dataSource, f(r))))
} else {
ZIO.foreach(as)(r => uncachedResult(dataSource, f(r)))
}
ZQuery.currentCache.getWith {
case Some(cache) => CachedResult.foreach(as)(r => cachedResult(cache, dataSource, f(r)))
case _ => ZIO.foreach(as)(r => uncachedResult(dataSource, f(r)))
}.map(collectResults(as, _, mode = 2))
}

@@ -1531,12 +1536,9 @@ object ZQuery {
f: In => A
)(dataSource: DataSource[R, A])(implicit ev: A <:< Request[E, B], trace: Trace): ZQuery[R, E, List[B]] =
ZQuery {
ZQuery.cachingEnabled.getWith {
if (_) {
ZQuery.currentCache.getWith(cache => CachedResult.foreach(as)(r => cachedResult(cache, dataSource, f(r))))
} else {
ZIO.foreach(as)(r => uncachedResult(dataSource, f(r)))
}
ZQuery.currentCache.getWith {
case Some(cache) => CachedResult.foreach(as)(r => cachedResult(cache, dataSource, f(r)))
case _ => ZIO.foreach(as)(r => uncachedResult(dataSource, f(r)))
}.map(collectResults(as, _, mode = 2))
}

@@ -1825,11 +1827,15 @@ object ZQuery {
(bs.result(), cs.result())
}

@deprecated("No longer used, kept for binary compatibility only", "0.7.4")
val cachingEnabled: FiberRef[Boolean] =
FiberRef.unsafe.make(true)(Unsafe.unsafe)

val currentCache: FiberRef[Cache] =
FiberRef.unsafe.make(Cache.unsafeMake())(Unsafe.unsafe)
val currentCache: FiberRef[Option[Cache]] =
FiberRef.unsafe.make(Option.empty[Cache])(Unsafe.unsafe)

private val disabledCache: FiberRef[Option[Cache]] =
FiberRef.unsafe.make(Option.empty[Cache])(Unsafe.unsafe)

val currentScope: FiberRef[QueryScope] =
FiberRef.unsafe.make[QueryScope](QueryScope.NoOp)(Unsafe.unsafe)
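For callers nothing changes: `runCache` still takes a plain `Cache` and now stores it as `Some(cache)` internally (see the `@@ -539` hunk above). A minimal sketch of running a query against an explicitly supplied cache, assuming `Cache.empty` from zio-query:

import zio._
import zio.query._

// Run a query against a freshly created, explicitly supplied cache.
def runWithOwnCache[R, E, A](query: ZQuery[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
  Cache.empty.flatMap(cache => query.runCache(cache))
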
61 changes: 30 additions & 31 deletions zio-query/shared/src/main/scala/zio/query/internal/BlockedRequests.scala
@@ -105,42 +105,41 @@ private[query] sealed trait BlockedRequests[-R] { self =>
/**
* Executes all requests, submitting requests to each data source in parallel.
*/
def run(implicit trace: Trace): ZIO[R, Nothing, Unit] =
ZQuery.currentCache.getWith { cache =>
val flattened = BlockedRequests.flatten(self)
ZIO.foreachDiscard(flattened) { requestsByDataSource =>
ZIO.foreachParDiscard(requestsByDataSource.toIterable) { case (dataSource, sequential) =>
val requests = sequential.map(_.map(_.request))

dataSource
.runAll(requests)
.catchAllCause(cause =>
ZIO.succeed {
CompletedRequestMap.failCause(
requests.flatten.asInstanceOf[Chunk[Request[Any, Any]]],
cause
)
}
)
.flatMap { completedRequests =>
ZQuery.cachingEnabled.getWith {
val completedRequestsM = mutable.HashMap.from(completedRequests.underlying)
if (_) {
completePromises(dataSource, sequential) { req =>
// Pop the entry, and fallback to the immutable one if we already removed it
completedRequestsM.remove(req) orElse completedRequests.lookup(req)
}
// cache responses that were not requested but were completed by the DataSource
if (completedRequestsM.nonEmpty) cacheLeftovers(cache, completedRequestsM) else ZIO.unit
} else {
// No need to remove entries here since we don't need to know which ones we need to put in the cache
ZIO.succeed(completePromises(dataSource, sequential)(completedRequestsM.get))
def run(implicit trace: Trace): ZIO[R, Nothing, Unit] = {
val flattened = BlockedRequests.flatten(self)
ZIO.foreachDiscard(flattened) { requestsByDataSource =>
ZIO.foreachParDiscard(requestsByDataSource.toIterable) { case (dataSource, sequential) =>
val requests = sequential.map(_.map(_.request))

dataSource
.runAll(requests)
.catchAllCause(cause =>
Exit.succeed {
CompletedRequestMap.failCause(
requests.flatten.asInstanceOf[Chunk[Request[Any, Any]]],
cause
)
}
)
.flatMap { completedRequests =>
val completedRequestsM = mutable.HashMap.from(completedRequests.underlying)
ZQuery.currentCache.getWith {
case Some(cache) =>
completePromises(dataSource, sequential) { req =>
// Pop the entry, and fallback to the immutable one if we already removed it
completedRequestsM.remove(req) orElse completedRequests.lookup(req)
}
// cache responses that were not requested but were completed by the DataSource
if (completedRequestsM.nonEmpty) cacheLeftovers(cache, completedRequestsM) else Exit.unit
case _ => {
// No need to remove entries here since we don't need to know which ones we need to put in the cache
ZIO.succeed(completePromises(dataSource, sequential)(completedRequestsM.get))
}
}
}
}
}
}
}
}

private[query] object BlockedRequests {
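A secondary change in this hunk swaps `ZIO.succeed` for `Exit.succeed`/`Exit.unit` where the value is already computed; in ZIO 2 an `Exit` is itself a `ZIO`, so this avoids constructing and then running a thunk. For instance:

import zio._

// Both values have type ZIO[Any, Nothing, Int]; the Exit form is already
// evaluated, so the runtime does not need to suspend and run a thunk first.
val suspended: ZIO[Any, Nothing, Int] = ZIO.succeed(42)
val immediate: ZIO[Any, Nothing, Int] = Exit.succeed(42)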
