From 637743a36bd1d983a578f18bfc36ebdfa1875947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Chantepie?= Date: Sat, 9 Jul 2022 19:38:34 +0200 Subject: [PATCH] Fix #457 - Handle SQL initialization error in Akka (#463) --- akka/src/main/scala/anorm/AkkaStream.scala | 25 +++++++++---- akka/src/test/scala-2.13+/AkkaCompat.scala | 5 +++ akka/src/test/scala-2.13-/AkkaCompat.scala | 5 +++ .../src/test/scala/anorm/AkkaStreamSpec.scala | 35 +++++++++++++++++-- build.sbt | 12 ++++++- 5 files changed, 71 insertions(+), 11 deletions(-) create mode 100644 akka/src/test/scala-2.13+/AkkaCompat.scala create mode 100644 akka/src/test/scala-2.13-/AkkaCompat.scala diff --git a/akka/src/main/scala/anorm/AkkaStream.scala b/akka/src/main/scala/anorm/AkkaStream.scala index 825ec13b..6a805acf 100644 --- a/akka/src/main/scala/anorm/AkkaStream.scala +++ b/akka/src/main/scala/anorm/AkkaStream.scala @@ -2,6 +2,8 @@ package anorm import java.sql.Connection +import scala.util.control.NonFatal + import scala.concurrent.{ Future, Promise } import akka.stream.Materializer @@ -114,13 +116,24 @@ object AkkaStream { override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[Int]) = { val result = Promise[Int]() + val logic = new GraphStageLogic(shape) with OutHandler { private var cursor: Option[Cursor] = None private var counter: Int = 0 + private def failWith(cause: Throwable): Unit = { + result.failure(cause) + fail(out, cause) + () + } + override def preStart(): Unit = { - resultSet = sql.unsafeResultSet(connection) - nextCursor() + try { + resultSet = sql.unsafeResultSet(connection) + nextCursor() + } catch { + case NonFatal(cause) => failWith(cause) + } } override def postStop() = release() @@ -152,10 +165,8 @@ object AkkaStream { nextCursor() } - case Failure(cause) => { - result.failure(cause) - fail(out, cause) - } + case Failure(cause) => + failWith(cause) } case _ => { @@ -172,7 +183,7 @@ object AkkaStream { setHandler(out, this) } - (logic, result.future) + logic -> result.future } } diff --git a/akka/src/test/scala-2.13+/AkkaCompat.scala b/akka/src/test/scala-2.13+/AkkaCompat.scala new file mode 100644 index 00000000..179b4c91 --- /dev/null +++ b/akka/src/test/scala-2.13+/AkkaCompat.scala @@ -0,0 +1,5 @@ +package anorm + +private[anorm] object AkkaCompat { + type Seq[T] = _root_.scala.collection.immutable.Seq[T] +} diff --git a/akka/src/test/scala-2.13-/AkkaCompat.scala b/akka/src/test/scala-2.13-/AkkaCompat.scala new file mode 100644 index 00000000..ca940312 --- /dev/null +++ b/akka/src/test/scala-2.13-/AkkaCompat.scala @@ -0,0 +1,5 @@ +package anorm + +private[anorm] object AkkaCompat { + type Seq[T] = _root_.scala.collection.Seq[T] +} diff --git a/akka/src/test/scala/anorm/AkkaStreamSpec.scala b/akka/src/test/scala/anorm/AkkaStreamSpec.scala index 2168fe70..d5afea9b 100644 --- a/akka/src/test/scala/anorm/AkkaStreamSpec.scala +++ b/akka/src/test/scala/anorm/AkkaStreamSpec.scala @@ -9,6 +9,7 @@ import scala.concurrent.duration._ import akka.stream.scaladsl.{ Keep, Sink, Source } +import acolyte.jdbc.QueryResult import acolyte.jdbc.AcolyteDSL.withQueryResult import acolyte.jdbc.Implicits._ import acolyte.jdbc.RowLists.stringList @@ -29,7 +30,9 @@ final class AkkaStreamSpec(implicit ee: ExecutionEnv) extends org.specs2.mutable "Akka Stream" should { "expose the query result as source" in assertAllStagesStopped { withQueryResult(stringList :+ "A" :+ "B" :+ "C") { implicit con => - AkkaStream.source(SQL"SELECT * FROM Test", SqlParser.scalar[String]).runWith(Sink.seq[String]) must beEqualTo( + AkkaStream + .source(SQL"SELECT * FROM Test", SqlParser.scalar[String]) + .runWith(Sink.seq[String]) must beTypedEqualTo( Seq("A", "B", "C") ).await(0, 5.seconds) } @@ -40,7 +43,7 @@ final class AkkaStreamSpec(implicit ee: ExecutionEnv) extends org.specs2.mutable AkkaStream .source(SQL"SELECT * FROM Test", SqlParser.scalar[String]) .toMat(Sink.ignore)(Keep.left) - .run() must beEqualTo(3).await(0, 3.seconds) + .run() must beTypedEqualTo(3).await(0, 3.seconds) } } @@ -79,7 +82,33 @@ final class AkkaStreamSpec(implicit ee: ExecutionEnv) extends org.specs2.mutable } } - "on failure" in (withQueryResult(stringList :+ "A" :+ "B" :+ "C")) { implicit con => + "on failed initialization" in { + import java.sql.SQLException + + withQueryResult(QueryResult.Nil) { implicit con => + val failingSql = new Sql { + import java.sql.PreparedStatement + + def unsafeStatement( + connection: Connection, + generatedColumn: String, + generatedColumns: AkkaCompat.Seq[String] + ): PreparedStatement = ??? + + def unsafeStatement(connection: Connection, getGeneratedKeys: Boolean): PreparedStatement = + throw new SQLException("Init failure") + + def resultSetOnFirstRow: Boolean = ??? + } + + val graph = source(failingSql, SqlParser.scalar[String]) + val mat = Source.fromGraph(graph).toMat(Sink.ignore)(Keep.left).run() + + mat must throwA[SQLException]("Init failure").awaitFor(3.seconds) + } + } + + "on failure" in withQueryResult(stringList :+ "A" :+ "B" :+ "C") { implicit con => assertAllStagesStopped { val rSet = run(Sink.reduce[String] { (_, _) => sys.error("Foo") }) diff --git a/build.sbt b/build.sbt index 62b57df5..d52ad93b 100644 --- a/build.sbt +++ b/build.sbt @@ -221,7 +221,17 @@ lazy val `anorm-akka` = (project in file("akka")) libraryDependencies ++= (acolyte +: specs2Test) ++ Seq( "com.typesafe.akka" %% "akka-stream-contrib" % akkaContribVer.value % Test ), - scalacOptions += "-P:silencer:globalFilters=deprecated" + scalacOptions += "-P:silencer:globalFilters=deprecated", + Test / unmanagedSourceDirectories ++= { + CrossVersion.partialVersion(scalaVersion.value) match { + case Some((2, n)) if n < 13 => + Seq((Test / sourceDirectory).value / "scala-2.13-") + + case _ => + Seq((Test / sourceDirectory).value / "scala-2.13+") + + } + } ) .dependsOn(`anorm-core`)