From 7998b4c2ee5cd3ff44d89d1726c6d54d0e4f5197 Mon Sep 17 00:00:00 2001 From: Nat Wilson Date: Thu, 18 Mar 2021 10:49:37 -0700 Subject: [PATCH] Adds an example demonstrating issue #2330 --- .../main/scala/UnreliableInterruption.scala | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 core/jvm/src/main/scala/UnreliableInterruption.scala diff --git a/core/jvm/src/main/scala/UnreliableInterruption.scala b/core/jvm/src/main/scala/UnreliableInterruption.scala new file mode 100644 index 0000000000..840ce4823d --- /dev/null +++ b/core/jvm/src/main/scala/UnreliableInterruption.scala @@ -0,0 +1,67 @@ +import cats.effect.concurrent.{Deferred, Ref} +import cats.effect.{ExitCode, IO, IOApp} +import fs2.{Pipe, Pull, Stream} + +object UnreliableInterruption extends IOApp { + private def resume[A, B](mk: A => Stream[IO, B], checkpoint: B => A)(start: A): Stream[IO, B] = { + def go(s: Stream[IO, Either[Throwable, B]], watermark: A): Pull[IO, B, Unit] = s.pull.uncons1.flatMap { + case Some((Right(b), rest)) => Pull.output1(b) >> go(rest, checkpoint(b)) + case Some((Left(_), _)) => go(mk(watermark).attempt, watermark) + case None => go(mk(watermark).attempt, watermark) + } + + go(mk(start).attempt, start).stream + } + + // Interrupt the stream after a five items, up to a max number of times + private def interrupter[A](deferred: Deferred[IO, Unit], maxInterrupts: Int, interruptCount: Ref[IO, Int]): Pipe[IO, A, A] = { + input: Stream[IO, A] => + input.zipWithIndex + .evalTap { + case (_, 5) => interruptCount.getAndUpdate(_ + 1).flatMap { i => + if (i < maxInterrupts) { + deferred.complete(()) + } else IO.unit + } + case _ => IO.unit + } + .map(_._1) + .interruptWhen(deferred.get.attempt) + } + + + def run(args: List[String]): IO[ExitCode] = { + val stream: Int => Stream[IO, Int] = Stream.iterate(_)(_ + 1) + + val usuallyWorks = for { + interruptCount <- Ref.of[IO, Int](0) + msg <- resume[Int, Int]( + start => Stream.eval(Deferred[IO, Unit]).flatMap(d => stream(start).through(interrupter(d, 1, interruptCount))), + _ + 1 + )(0) + .take(1000) + .compile + .toList + .map(lst => s"${lst.size} should be 1000, meaning it restarted once") + } yield msg + + val usuallyDoesNot = for { + interruptCount <- Ref.of[IO, Int](0) + msg <- resume[Int, Int]( + start => Stream.eval(Deferred[IO, Unit]).flatMap(d => stream(start).through(interrupter(d, 10, interruptCount))), + _ + 1 + )(0) + .take(1000) + .compile + .toList + .map(lst => s"${lst.size} should be 1000, meaning it restarted 10 times (but it's not)") + } yield msg + + val output = for { + a <- usuallyWorks + b <- usuallyDoesNot + } yield List(a, b).mkString("\n") + + output.map(println).as(ExitCode.Success) + } +}