Add support for client and server message buffering #28

Open: wants to merge 5 commits into master. Changes shown from 1 commit.
@@ -41,8 +41,6 @@ class GrpcServicePrinter(
private def grpcDescriptor(method: MethodDescriptor): ScalaName =
companionObject(method.getService) / s"METHOD_${NameUtils.toAllCaps(method.getName)}"

private[this] val serverCalls = "_root_.io.grpc.stub.ServerCalls"

private[this] def methodDescriptor(method: MethodDescriptor) = PrinterEndo { p =>
def marshaller(t: ExtendedMethodDescriptor#MethodTypeWrapper) =
if (t.customScalaType.isDefined)
157 changes: 115 additions & 42 deletions grpc-runtime/src/main/scala/monix/grpc/runtime/server/ServerCall.scala
@@ -5,18 +5,35 @@ import monix.eval.Task
import monix.execution.{AsyncVar, BufferCapacity}
import monix.reactive.Observable

// TODO: Add attributes, compression, message compression.
class ServerCall[Request, Response] private (
val call: grpc.ServerCall[Request, Response]
) extends AnyVal {

def isReady: Boolean = call.isReady
/**
* Defines a server call that accepts a client `Request` and returns a
* server `Response`. This is the Monix-based counterpart of the wrapped
* `grpc.ServerCall`: it mirrors the underlying API as closely as possible
* while using `Task` and `Observable` as input and output types.
*
* @param call is the instance of the wrapped grpc server call.
* @param options is a field for custom user-defined server call options.
*/
final class ServerCall[Request, Response] private (
Collaborator comment:
One thing that is not clear to me yet is how you set this ServerCallOptions.
Is it meant to be set while registering the API to the GrpcServer?

val call: grpc.ServerCall[Request, Response],
val options: ServerCallOptions
) {

/**
* Requests up to the given number of messages from the call to be delivered
* to {@link Listener#onMessage(Object)}. Once {@code numMessages} have been
* delivered, no further request messages will be delivered until more
* messages are requested by calling this method again.
*
* Servers use this mechanism to provide back-pressure to the client for
* flow-control. This method is safe to call from multiple threads without
* external synchronization.
*
* @param numMessages the number of messages to be delivered to the listener.
*/
def request(numMessages: Int): Task[Unit] =
handleError(
Task(call.request(numMessages)),
s"Failed to request message $numMessages!"
)
handleError(Task(call.request(numMessages)), s"Failed to request message $numMessages!")

/**
* Asks for two messages even though we expect only one so that if a
@@ -28,12 +45,47 @@ class ServerCall[Request, Response] private (
*/
def requestMessagesFromUnaryCall: Task[Unit] = request(2)

/**
* Sends response header metadata prior to sending a response message. This
* method may only be called once and cannot be called after calls to
* {@link #sendMessage} or {@link #close}.
*
* Since {@link Metadata} is not thread-safe, the caller must not access
* (read or write) {@code headers} after this point.
*
* @param headers metadata to send prior to any response body.
* @return a task that will complete when headers are successfully sent or
* that will throw an `IllegalStateException` error if {@code close} has
* been called, a message has been sent, or headers have already been sent.
*/
def sendHeaders(headers: grpc.Metadata): Task[Unit] =
handleError(Task(call.sendHeaders(headers)), s"Failed to send headers!", headers)

/**
* Sends a response message. Messages are the primary form of communication
* associated with RPCs. Multiple response messages may exist for streaming
* calls.
*
* @param message is the response message to send to the client.
* @return a task that will complete when the message is successfully sent or
* that will throw an `IllegalStateException` error if headers have not been
* sent or the call is {@link #close}d.
*/
def sendMessage(message: Response): Task[Unit] =
handleError(Task(call.sendMessage(message)), s"Failed to send message $message!")

/**
* Subscribes to the `responses` observable provided as a parameter and sends
* each received response to the client with built-in flow control.
*
* @param responses is the observable that streams responses.
* @param onReady is an async var that will be full when the client is ready
* to receive another message and thus the server can send a response.
*
* @return a task that will complete when all messages are successfully sent or
* that will throw an `IllegalStateException` error if headers have not been
* sent or the call is {@link #close}d.
*/
def sendStreamingResponses(
responses: Observable[Response],
onReady: AsyncVar[Unit]
@@ -48,9 +100,48 @@ class ServerCall[Request, Response] private (
}.completedL
}

def closeStream(status: grpc.Status, trailers: grpc.Metadata): Task[Unit] =
/**
* Close the call with the provided status. No further sending or receiving
* will occur. If {@link Status#isOk} is {@code false}, then the call is
* said to have failed.
*
* If no errors or cancellations are known to have occurred, then a
* {@link Listener#onComplete} notification should be expected, independent
* of {@code status}. Otherwise {@link Listener#onCancel} has been or will
* be called.
*
* Since {@link Metadata} is not thread-safe, the caller must not access
* (read or write) {@code trailers} after this point.
*
* This method implies the caller completed processing the RPC, but it does
* not imply the RPC is complete. The call implementation will need
* additional time to complete the RPC and during this time the client is
* still able to cancel the request or a network error might cause the RPC
* to fail. If you wish to know when the call is actually completed/closed,
* you have to use {@link Listener#onComplete} or {@link Listener#onCancel}
* instead. This method is not necessarily invoked when
* {@link Listener#onCancel()} is called.
*
* @return a task that will complete when successfully closed and will throw
* `IllegalStateException` if the call is already {@code close}d.
*/
def close(status: grpc.Status, trailers: grpc.Metadata): Task[Unit] =
Task.delay(call.close(status, trailers))

/**
* If {@code true}, indicates that the call is capable of sending additional
* messages without requiring excessive buffering internally. This event is
* just a suggestion and the application is free to ignore it, however
* doing so may result in excessive buffering within the call.
*
* If {@code false}, {@link Listener#onReady()} will be called after
* {@code isReady()} transitions to {@code true}.
*
* This method delegates to the wrapped {@code grpc.ServerCall}'s
* {@code isReady} implementation.
*/
private def isReady: Boolean = call.isReady

private def handleError(
effect: Task[Unit],
errorMsg: String,
@@ -65,40 +156,22 @@ class ServerCall[Request, Response] private (
}

object ServerCall {

/**
* Creates a server call that accepts a client `Request` and returns
* a server `Response`. The resulting wrapper supports Monix and
* automatic flow control out of the box, exposing `Task` and
* `Observable` as part of its API.
*/
def apply[Request, Response](
call: grpc.ServerCall[Request, Response],
options: ServerCallOptions
): ServerCall[Request, Response] = {
val compressions = options.compressor.map(_.name)
compressions.foreach(call.setCompression)
new ServerCall(call)
options.enabledMessageCompression.foreach(call.setMessageCompression)
options.compressor.foreach { compressor =>
grpc.CompressorRegistry.getDefaultInstance().register(compressor)
compressor.getMessageEncoding()
}
new ServerCall(call, options)
}
}
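
To make the intended call sequence concrete, here is a hedged usage sketch built only from the methods shown above (ServerCall.apply, sendHeaders, sendMessage, close) plus the ServerCallOptions() default taken from the handler signatures later in this diff; the object name, respondOnce, and the underlying parameter are made up for illustration.

import io.grpc
import monix.eval.Task
import monix.grpc.runtime.server.{ServerCall, ServerCallOptions}

object UnaryResponseSketch {
  // `underlying` stands for the raw grpc.ServerCall handed to a ServerCallHandler.
  def respondOnce[Req, Res](
      underlying: grpc.ServerCall[Req, Res],
      response: Res
  ): Task[Unit] = {
    val call = ServerCall(underlying, ServerCallOptions())
    for {
      _ <- call.sendHeaders(new grpc.Metadata())           // headers go first
      _ <- call.sendMessage(response)                      // then the single response message
      _ <- call.close(grpc.Status.OK, new grpc.Metadata()) // finally close with an OK status
    } yield ()
  }
}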

abstract class ServerCallOptions private (
val compressor: Option[ServerCompressor],
val bufferCapacity: BufferCapacity
) {
// needs to be private for binary compatibility
private def copy(
compressor: Option[ServerCompressor] = this.compressor,
bufferCapacity: BufferCapacity
): ServerCallOptions = new ServerCallOptions(compressor, bufferCapacity) {}

def withServerCompressor(
compressor: Option[ServerCompressor]
): ServerCallOptions = copy(compressor, bufferCapacity)

def withBufferCapacity(
bufferCapacity: BufferCapacity
): ServerCallOptions = copy(compressor, bufferCapacity)
}

object ServerCallOptions {
val default: ServerCallOptions =
new ServerCallOptions(Some(GzipCompressor), BufferCapacity.Bounded(32)) {}
}

abstract sealed class ServerCompressor(val name: String) extends Product with Serializable

case object GzipCompressor extends ServerCompressor("gzip")
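
Relating to the collaborator question above about where ServerCallOptions gets set: judging by the handler factory signatures later in this diff, the options are supplied per handler and default to ServerCallOptions(). A hedged sketch of that wiring follows; the echo handler and the explicit scheduler are made up for illustration.

import io.grpc
import monix.eval.Task
import monix.execution.Scheduler
import monix.grpc.runtime.server.{ServerCallHandlers, ServerCallOptions}

object OptionsWiringSketch {
  implicit val scheduler: Scheduler = Scheduler.global

  // Each handler factory accepts its own ServerCallOptions, so a caller can
  // adjust buffering or compression per method when binding the service.
  val echoHandler: grpc.ServerCallHandler[String, String] =
    ServerCallHandlers.unaryToUnaryCall[String, String](
      (request, _: grpc.Metadata) => Task.now(request),
      options = ServerCallOptions() // swap in customized options here
    )
}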
@@ -24,7 +24,7 @@ object ServerCallHandlers {
*/
def unaryToUnaryCall[T, R](
f: (T, grpc.Metadata) => Task[R],
options: ServerCallOptions = ServerCallOptions.default
options: ServerCallOptions = ServerCallOptions()
)(implicit
scheduler: Scheduler
): grpc.ServerCallHandler[T, R] = new grpc.ServerCallHandler[T, R] {
@@ -53,7 +53,7 @@ object ServerCallHandlers {
*/
def unaryToStreamingCall[T, R](
f: (T, grpc.Metadata) => Observable[R],
options: ServerCallOptions = ServerCallOptions.default
options: ServerCallOptions = ServerCallOptions()
)(implicit
scheduler: Scheduler
): grpc.ServerCallHandler[T, R] = new grpc.ServerCallHandler[T, R] {
@@ -138,7 +138,7 @@ object ServerCallHandlers {
*/
def streamingToUnaryCall[T, R](
f: (Observable[T], grpc.Metadata) => Task[R],
options: ServerCallOptions = ServerCallOptions.default
options: ServerCallOptions = ServerCallOptions()
)(implicit
scheduler: Scheduler
): grpc.ServerCallHandler[T, R] = new grpc.ServerCallHandler[T, R] {
@@ -147,7 +147,7 @@
metadata: grpc.Metadata
): grpc.ServerCall.Listener[T] = {
val call = ServerCall(grpcCall, options)
val listener = new StreamingCallListener(call, options.bufferCapacity)(scheduler)
val listener = new StreamingCallListener(call)(scheduler)
listener.runStreamingResponseListener(metadata) { msgs =>
Task.defer(f(msgs, metadata)).flatMap(call.sendMessage)
}
@@ -166,7 +166,7 @@
*/
def streamingToStreamingCall[T, R](
f: (Observable[T], grpc.Metadata) => Observable[R],
options: ServerCallOptions = ServerCallOptions.default
options: ServerCallOptions = ServerCallOptions()
)(implicit
scheduler: Scheduler
): grpc.ServerCallHandler[T, R] = new grpc.ServerCallHandler[T, R] {
@@ -176,7 +176,7 @@
): grpc.ServerCall.Listener[T] = {

val call = ServerCall(grpcCall, options)
val listener = new StreamingCallListener(call, options.bufferCapacity)(scheduler)
val listener = new StreamingCallListener(call)(scheduler)
listener.runStreamingResponseListener(metadata) { msgs =>
call.sendStreamingResponses(
Observable.defer(f(msgs, metadata)),
@@ -188,8 +188,7 @@
}

private[server] final class StreamingCallListener[Request, Response](
call: ServerCall[Request, Response],
capacity: BufferCapacity
call: ServerCall[Request, Response]
)(implicit
scheduler: Scheduler
) extends grpc.ServerCall.Listener[Request] {
@@ -201,17 +200,24 @@

val onReadyEffect: AsyncVar[Unit] = AsyncVar.empty[Unit]()

private[this] val bufferSize = call.options.bufferSize.getOrElse(0)
def runStreamingResponseListener(
metadata: grpc.Metadata
)(
sendResponses: Observable[Request] => Task[Unit]
): Unit = {
def bufferResponsesUpTo(responses: Observable[Request]) =
if (bufferSize == 0) responses
else responses.asyncBoundary(OverflowStrategy.BackPressure(bufferSize))

val handleResponse = for {
_ <- call.sendHeaders(metadata)
_ <- sendResponses {
subject
.doAfterSubscribe(call.request(1))
.doOnNext(_ => call.request(1))
bufferResponsesUpTo {
subject
.doAfterSubscribe(call.request(1))
.doOnNext(_ => call.request(1))
}
}
} yield ()

@@ -244,11 +250,11 @@ object ServerCallHandlers {
isCancelled: CancelablePromise[Unit]
): Task[Unit] = {
val finalHandler = handleResponse.guaranteeCase {
case ExitCase.Completed => call.closeStream(grpc.Status.OK, new grpc.Metadata())
case ExitCase.Completed => call.close(grpc.Status.OK, new grpc.Metadata())
case ExitCase.Error(err) => reportError(err, call, new grpc.Metadata())
case ExitCase.Canceled =>
val description = "Propagating cancellation because server response handler was cancelled!"
call.closeStream(grpc.Status.CANCELLED.withDescription(description), new grpc.Metadata())
call.close(grpc.Status.CANCELLED.withDescription(description), new grpc.Metadata())
}

// If `isCancelled` is completed, then client cancelled the grpc call and
@@ -264,13 +270,13 @@
err match {
case err: grpc.StatusException =>
val metadata = Option(err.getTrailers).getOrElse(new grpc.Metadata())
call.closeStream(err.getStatus, metadata)
call.close(err.getStatus, metadata)
case err: grpc.StatusRuntimeException =>
val metadata = Option(err.getTrailers).getOrElse(new grpc.Metadata())
call.closeStream(err.getStatus, metadata)
call.close(err.getStatus, metadata)
case err =>
val status = grpc.Status.INTERNAL.withDescription(err.getMessage).withCause(err)
call.closeStream(status, unknownErrorMetadata)
call.close(status, unknownErrorMetadata)
}
}
}
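
To close, a standalone sketch of the buffering behavior that bufferResponsesUpTo introduces above, using only Monix and no gRPC: an asyncBoundary with a back-pressured buffer lets the producer run ahead of a slow consumer by at most bufferSize elements. The concrete sizes and the artificial delay are made up for the example.

import monix.eval.Task
import monix.execution.Scheduler.Implicits.global
import monix.reactive.{Observable, OverflowStrategy}
import scala.concurrent.duration._

object BufferingSketch extends App {
  val bufferSize = 8 // loosely corresponds to ServerCallOptions.bufferSize in this PR

  val consumed = Observable
    .range(0L, 32L)
    // Detach the producer from the consumer; at most `bufferSize` elements are
    // buffered before the producer is back-pressured again.
    .asyncBoundary(OverflowStrategy.BackPressure(bufferSize))
    .mapEval(i => Task(i).delayExecution(10.millis)) // simulate a slow consumer
    .toListL

  println(consumed.runSyncUnsafe())
}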