Skip to content

Commit

Permalink
Handle stream completion (propagate failures) (#340)
Browse files Browse the repository at this point in the history
* Handle stream completion (propagate failures)


Co-authored-by: Arnout Engelen <[email protected]>
  • Loading branch information
ignasi35 and raboof authored May 11, 2019
1 parent 8fdb495 commit f99dbb2
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 62 deletions.
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ lazy val mimaSettings = mimaDefaultSettings ++ Seq(
mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[DirectMissingMethodProblem]("play.libs.ws.ahc.StandaloneAhcWSResponse.getBodyAsSource"),
ProblemFilters.exclude[MissingClassProblem]("play.api.libs.ws.package$"),
ProblemFilters.exclude[MissingClassProblem]("play.api.libs.ws.package")
ProblemFilters.exclude[MissingClassProblem]("play.api.libs.ws.package"),
ProblemFilters.exclude[DirectMissingMethodProblem]("play.api.libs.ws.ahc.DefaultStreamedAsyncHandler.this")
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,36 @@

package play.libs.ws.ahc;

import akka.Done;
import akka.stream.Materializer;
import akka.stream.javadsl.Source;
import akka.util.ByteString;
import akka.util.ByteStringBuilder;
import com.typesafe.sslconfig.ssl.SystemConfiguration;
import com.typesafe.sslconfig.ssl.debug.DebugConfiguration;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.LoggerFactory;
import play.api.libs.ws.ahc.AhcConfigBuilder;
import play.api.libs.ws.ahc.AhcLoggerFactory;
import play.api.libs.ws.ahc.AhcWSClientConfig;
import play.api.libs.ws.ahc.DefaultStreamedAsyncHandler;
import play.api.libs.ws.ahc.*;
import play.api.libs.ws.ahc.cache.AhcHttpCache;
import play.api.libs.ws.ahc.cache.CachingAsyncHttpClient;
import play.libs.ws.StandaloneWSClient;
import play.libs.ws.StandaloneWSResponse;
import play.shaded.ahc.org.asynchttpclient.*;
import scala.Function1;
import scala.compat.java8.FutureConverters;
import scala.compat.java8.FutureConverters$;
import scala.concurrent.ExecutionContext;
import scala.concurrent.Future;
import scala.concurrent.Promise;
import scala.util.Try;

import javax.inject.Inject;
import java.io.IOException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;

/**
* A WS asyncHttpClient backed by an AsyncHttpClient instance.
Expand Down Expand Up @@ -92,33 +97,73 @@ public void onThrowable(Throwable t) {
}

CompletionStage<StandaloneWSResponse> executeStream(Request request, ExecutionContext ec) {
final Promise<StandaloneWSResponse> scalaPromise = scala.concurrent.Promise$.MODULE$.apply();
final Promise<StandaloneWSResponse> streamStarted = scala.concurrent.Promise$.MODULE$.apply();
final Promise<Done> streamCompletion = scala.concurrent.Promise$.MODULE$.apply();

Function<StreamedState, StandaloneWSResponse> f = state -> {
Publisher<HttpResponseBodyPart> publisher = state.publisher();
Publisher<HttpResponseBodyPart> wrap = new Publisher<HttpResponseBodyPart>() {
@Override
public void subscribe(Subscriber<? super HttpResponseBodyPart> s) {
publisher.subscribe(
new Subscriber<HttpResponseBodyPart>() {
@Override
public void onSubscribe(Subscription sub) {
s.onSubscribe(sub);
}

@Override
public void onNext(HttpResponseBodyPart httpResponseBodyPart) {
s.onNext(httpResponseBodyPart);
}

@Override
public void onError(Throwable t) {
s.onError(t);
}

@Override
public void onComplete() {
FutureConverters$.MODULE$.toJava(streamCompletion.future())
.handle((d, t) -> {
if (d != null) s.onComplete();
else s.onError(t);
return null;
});
}
}
);
}
};

return new StreamedResponse(this,
state.statusCode(),
state.statusText(),
state.uriOption().get(),
state.responseHeaders(),
wrap);
};

asyncHttpClient.executeRequest(request, new DefaultStreamedAsyncHandler<>(state ->
new StreamedResponse(this,
state.statusCode(),
state.statusText(),
state.uriOption().get(),
state.responseHeaders(),
state.publisher()),
scalaPromise));
return FutureConverters.toJava(scalaPromise.future());
asyncHttpClient.executeRequest(request, new DefaultStreamedAsyncHandler<>(f,
streamStarted,
streamCompletion
));
return FutureConverters.toJava(streamStarted.future());
}

/**
* A convenience method for creating a StandaloneAhcWSClient from configuration.
*
* @param ahcWSClientConfig the configuration object
* @param materializer an akka materializer
* @param materializer an akka materializer
* @return a fully configured StandaloneAhcWSClient instance.
*
* @see #create(AhcWSClientConfig, AhcHttpCache, Materializer)
*/
public static StandaloneAhcWSClient create(AhcWSClientConfig ahcWSClientConfig, Materializer materializer) {
return create(
ahcWSClientConfig,
null /* no cache*/,
materializer
ahcWSClientConfig,
null /* no cache*/,
materializer
);
}

Expand Down Expand Up @@ -161,10 +206,10 @@ public static StandaloneAhcWSClient create(AhcWSClientConfig ahcWSClientConfig,
ByteString blockingToByteString(Source<ByteString, ?> bodyAsSource) {
try {
return bodyAsSource
.runFold(ByteString.createBuilder(), ByteStringBuilder::append, materializer)
.thenApply(ByteStringBuilder::result)
.toCompletableFuture()
.get();
.runFold(ByteString.createBuilder(), ByteStringBuilder::append, materializer)
.thenApply(ByteStringBuilder::result)
.toCompletableFuture()
.get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,32 @@
*/
package play.api.libs.ws.ahc

import akka.Done
import javax.inject.Inject

import akka.stream.Materializer
import akka.stream.scaladsl.Source
import akka.util.ByteString
import com.typesafe.sslconfig.ssl.SystemConfiguration
import com.typesafe.sslconfig.ssl.debug.DebugConfiguration
import org.reactivestreams.Publisher
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription
import play.api.libs.ws.ahc.cache._
import play.api.libs.ws.{ EmptyBody, StandaloneWSClient, StandaloneWSRequest }
import play.api.libs.ws.EmptyBody
import play.api.libs.ws.StandaloneWSClient
import play.api.libs.ws.StandaloneWSRequest
import play.shaded.ahc.org.asynchttpclient.uri.Uri
import play.shaded.ahc.org.asynchttpclient.{ Response => AHCResponse, _ }
import play.shaded.ahc.org.asynchttpclient.{ Response => AHCResponse }
import play.shaded.ahc.org.asynchttpclient._
import java.util.function.{ Function => JFunction }

import scala.collection.immutable.TreeMap
import scala.compat.java8.FunctionConverters
import scala.concurrent.{ Await, Future, Promise }
import scala.compat.java8.FunctionConverters._
import scala.concurrent.Await
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.util.Failure
import scala.util.Success

/**
* A WS client backed by an AsyncHttpClient.
Expand All @@ -29,7 +40,10 @@ import scala.concurrent.{ Await, Future, Promise }
* also close asyncHttpClient.
* @param materializer A materializer, meant to execute the stream
*/
class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implicit materializer: Materializer) extends StandaloneWSClient {
class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(
implicit
materializer: Materializer
) extends StandaloneWSClient {

/** Returns instance of AsyncHttpClient */
def underlying[T]: T = asyncHttpClient.asInstanceOf[T]
Expand Down Expand Up @@ -58,7 +72,9 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici
)
}

private[ahc] def execute(request: Request): Future[StandaloneAhcWSResponse] = {
private[ahc] def execute(
request: Request
): Future[StandaloneAhcWSResponse] = {
val result = Promise[StandaloneAhcWSResponse]()
val handler = new AsyncCompletionHandler[AHCResponse]() {
override def onCompleted(response: AHCResponse): AHCResponse = {
Expand Down Expand Up @@ -88,30 +104,71 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici
}

private[ahc] def executeStream(request: Request): Future[StreamedResponse] = {
val promise = Promise[StreamedResponse]()

val function = FunctionConverters.asJavaFunction[StreamedState, StreamedResponse](state =>
new StreamedResponse(
this,
state.statusCode,
state.statusText,
state.uriOption.get,
state.responseHeaders,
state.publisher)
val streamStarted = Promise[StreamedResponse]()
val streamCompletion = Promise[Done]()

val client = this

val function: JFunction[StreamedState, StreamedResponse] = {
state: StreamedState =>
val publisher = state.publisher

val wrap = new Publisher[HttpResponseBodyPart]() {
override def subscribe(
s: Subscriber[_ >: HttpResponseBodyPart]
): Unit = {
publisher.subscribe(new Subscriber[HttpResponseBodyPart] {
override def onSubscribe(sub: Subscription): Unit =
s.onSubscribe(sub)

override def onNext(t: HttpResponseBodyPart): Unit = s.onNext(t)

override def onError(t: Throwable): Unit = s.onError(t)

override def onComplete(): Unit = {
streamCompletion.future.onComplete {
case Success(_) => s.onComplete()
case Failure(t) => s.onError(t)
}(materializer.executionContext)
}
})
}

}
new StreamedResponse(
client,
state.statusCode,
state.statusText,
state.uriOption.get,
state.responseHeaders,
wrap
)

}.asJava
asyncHttpClient.executeRequest(
request,
new DefaultStreamedAsyncHandler[StreamedResponse](
function,
streamStarted,
streamCompletion
)
)
asyncHttpClient.executeRequest(request, new DefaultStreamedAsyncHandler[StreamedResponse](function, promise))
promise.future
streamStarted.future
}

private[ahc] def blockingToByteString(bodyAsSource: Source[ByteString, _]) = {
StandaloneAhcWSClient.logger.warn(s"blockingToByteString is a blocking and unsafe operation!")
StandaloneAhcWSClient.logger.warn(
s"blockingToByteString is a blocking and unsafe operation!"
)

import scala.concurrent.ExecutionContext.Implicits.global

val limitedSource = bodyAsSource.limit(StandaloneAhcWSClient.elementLimit)
val result = limitedSource.runFold(ByteString.createBuilder) { (acc, bs) =>
acc.append(bs)
}.map(_.result())
val result = limitedSource
.runFold(ByteString.createBuilder) { (acc, bs) =>
acc.append(bs)
}
.map(_.result())

Await.result(result, StandaloneAhcWSClient.blockingTimeout)
}
Expand All @@ -125,7 +182,9 @@ object StandaloneAhcWSClient {
val elementLimit = 13 // 13 8192k blocks is roughly 100k
private val logger = org.slf4j.LoggerFactory.getLogger(this.getClass)

private[ahc] val loggerFactory = new AhcLoggerFactory(org.slf4j.LoggerFactory.getILoggerFactory)
private[ahc] val loggerFactory = new AhcLoggerFactory(
org.slf4j.LoggerFactory.getILoggerFactory
)

/**
* Convenient factory method that uses a play.api.libs.ws.WSClientConfig value for configuration instead of
Expand All @@ -146,21 +205,26 @@ object StandaloneAhcWSClient {
* @param httpCache if not null, will be used for HTTP response caching.
* @param materializer the akka materializer.
*/
def apply(config: AhcWSClientConfig = AhcWSClientConfigFactory.forConfig(), httpCache: Option[AhcHttpCache] = None)(implicit materializer: Materializer): StandaloneAhcWSClient = {
def apply(
config: AhcWSClientConfig = AhcWSClientConfigFactory.forConfig(),
httpCache: Option[AhcHttpCache] = None
)(implicit materializer: Materializer): StandaloneAhcWSClient = {
if (config.wsClientConfig.ssl.debug.enabled) {
new DebugConfiguration(StandaloneAhcWSClient.loggerFactory).configure(config.wsClientConfig.ssl.debug)
new DebugConfiguration(StandaloneAhcWSClient.loggerFactory)
.configure(config.wsClientConfig.ssl.debug)
}
val ahcConfig = new AhcConfigBuilder(config).build()
val asyncHttpClient = new DefaultAsyncHttpClient(ahcConfig)
val wsClient = new StandaloneAhcWSClient(
httpCache.map { cache =>
new CachingAsyncHttpClient(asyncHttpClient, cache)
}.getOrElse {
asyncHttpClient
}
httpCache
.map { cache =>
new CachingAsyncHttpClient(asyncHttpClient, cache)
}
.getOrElse {
asyncHttpClient
}
)
new SystemConfiguration(loggerFactory).configure(config.wsClientConfig.ssl)
wsClient
}
}

Loading

0 comments on commit f99dbb2

Please sign in to comment.