Skip to content

Commit

Permalink
Fix keepalive logic (#50)
Browse files Browse the repository at this point in the history
This PR fixes grpc/grpc-swift#2095.

## Motivation

As per the gRPC specification, the server must keep track of pings from
each client, and if they go over a threshold, we must send a GOAWAY
frame and close the connection. We must reset the number of ping strikes
every time the server writes a headers or data frame. However, there is
a bug in the current keepalive implementation and we are not properly
keeping track of when header/data frames are written, so we never reset
the strikes, causing the server to always end up closing connections
when keepalive pings are enabled.

There was also a second bug where the GOAWAY frame wasn't actually sent
to the client because we were closing the connection straight away, and
the packet never made it out.

## Modifications

This PR fixes a couple of bugs:
- It keeps track of the appropriate FrameStats as described above
- It delays the channel close after sending the GOAWAY packet by a tick
to make sure it gets flushed and delivered to the client

## Results

Fewer bugs!
  • Loading branch information
gjcairo authored Dec 20, 2024
1 parent 289c0bc commit 47e0be1
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ extension ChannelPipeline.SynchronousOperations {
var http2HandlerStreamConfiguration = NIOHTTP2Handler.StreamConfiguration()
http2HandlerStreamConfiguration.targetWindowSize = clampedTargetWindowSize

let boundConnectionManagementHandler = NIOLoopBound(
serverConnectionHandler.syncView,
eventLoop: self.eventLoop
)
let streamMultiplexer = try self.configureAsyncHTTP2Pipeline(
mode: .server,
streamDelegate: serverConnectionHandler.http2StreamDelegate,
Expand All @@ -86,7 +90,8 @@ extension ChannelPipeline.SynchronousOperations {
acceptedEncodings: compressionConfig.enabledAlgorithms,
maxPayloadSize: rpcConfig.maxRequestPayloadSize,
methodDescriptorPromise: methodDescriptorPromise,
eventLoop: streamChannel.eventLoop
eventLoop: streamChannel.eventLoop,
connectionManagementHandler: boundConnectionManagementHandler.value
)
try streamChannel.pipeline.syncOperations.addHandler(streamHandler)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler {
}

/// Stats about recently written frames. Used to determine whether to reset keep-alive state.
private var frameStats: FrameStats
package var frameStats: FrameStats

struct FrameStats {
package struct FrameStats {
private(set) var didWriteHeadersOrData = false

/// Mark that a HEADERS frame has been written.
Expand Down Expand Up @@ -609,7 +609,13 @@ extension ServerConnectionManagementHandler {

context.write(self.wrapOutboundOut(goAway), promise: nil)
self.maybeFlush(context: context)
context.close(promise: nil)

// We must delay the channel close after sending the GOAWAY packet by a tick to make sure it
// gets flushed and delivered to the client before the connection is closed.
let loopBound = NIOLoopBound(context, eventLoop: context.eventLoop)
context.eventLoop.execute {
loopBound.value.close(promise: nil)
}

case .sendAck:
() // ACKs are sent by NIO's HTTP/2 handler, don't double ack.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan

private var cancellationHandle: Optional<ServerContext.RPCCancellationHandle>

package let connectionManagementHandler: ServerConnectionManagementHandler.SyncView

// Existential errors unconditionally allocate, avoid this per-use allocation by doing it
// statically.
private static let handlerRemovedBeforeDescriptorResolved: any Error = RPCError(
Expand All @@ -55,6 +57,7 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan
maxPayloadSize: Int,
methodDescriptorPromise: EventLoopPromise<MethodDescriptor>,
eventLoop: any EventLoop,
connectionManagementHandler: ServerConnectionManagementHandler.SyncView,
cancellationHandler: ServerContext.RPCCancellationHandle? = nil,
skipStateMachineAssertions: Bool = false
) {
Expand All @@ -66,6 +69,7 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan
self.methodDescriptorPromise = methodDescriptorPromise
self.cancellationHandle = cancellationHandler
self.eventLoop = eventLoop
self.connectionManagementHandler = connectionManagementHandler
}

package func setCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) {
Expand Down Expand Up @@ -136,13 +140,16 @@ extension GRPCServerStreamHandler {
switch self.stateMachine.nextInboundMessage() {
case .receiveMessage(let message):
context.fireChannelRead(self.wrapInboundOut(.message(message)))

case .awaitMoreMessages:
break loop

case .noMoreMessages:
context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
break loop
}
}

case .doNothing:
()
}
Expand Down Expand Up @@ -261,6 +268,7 @@ extension GRPCServerStreamHandler {
self.flushPending = true
let headers = try self.stateMachine.send(metadata: metadata)
context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise)
self.connectionManagementHandler.wroteHeadersFrame()
} catch let invalidState {
let error = RPCError(invalidState)
promise?.fail(error)
Expand All @@ -270,6 +278,7 @@ extension GRPCServerStreamHandler {
case .message(let message):
do {
try self.stateMachine.send(message: message, promise: promise)
self.connectionManagementHandler.wroteDataFrame()
} catch let invalidState {
let error = RPCError(invalidState)
promise?.fail(error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,24 @@ extension ConnectionTest {
let h2 = NIOHTTP2Handler(mode: .server)
let mux = HTTP2StreamMultiplexer(mode: .server, channel: channel) { stream in
let sync = stream.pipeline.syncOperations
let connectionManagementHandler = ServerConnectionManagementHandler(
eventLoop: stream.eventLoop,
maxIdleTime: nil,
maxAge: nil,
maxGraceTime: nil,
keepaliveTime: nil,
keepaliveTimeout: nil,
allowKeepaliveWithoutCalls: false,
minPingIntervalWithoutCalls: .minutes(5),
requireALPN: false
)
let handler = GRPCServerStreamHandler(
scheme: .http,
acceptedEncodings: .none,
maxPayloadSize: .max,
methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self),
eventLoop: stream.eventLoop
eventLoop: stream.eventLoop,
connectionManagementHandler: connectionManagementHandler.syncView
)

return stream.eventLoop.makeCompletedFuture {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,24 @@ final class TestServer: Sendable {
let sync = channel.pipeline.syncOperations
let multiplexer = try sync.configureAsyncHTTP2Pipeline(mode: .server) { stream in
stream.eventLoop.makeCompletedFuture {
let connectionManagementHandler = ServerConnectionManagementHandler(
eventLoop: stream.eventLoop,
maxIdleTime: nil,
maxAge: nil,
maxGraceTime: nil,
keepaliveTime: nil,
keepaliveTimeout: nil,
allowKeepaliveWithoutCalls: false,
minPingIntervalWithoutCalls: .minutes(5),
requireALPN: false
)
let handler = GRPCServerStreamHandler(
scheme: .http,
acceptedEncodings: .all,
maxPayloadSize: .max,
methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self),
eventLoop: stream.eventLoop
eventLoop: stream.eventLoop,
connectionManagementHandler: connectionManagementHandler.syncView
)

try stream.pipeline.syncOperations.addHandlers(handler)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,25 @@ final class GRPCServerStreamHandlerTests: XCTestCase {
descriptorPromise: EventLoopPromise<MethodDescriptor>? = nil,
disableAssertions: Bool = false
) -> GRPCServerStreamHandler {
let serverConnectionManagementHandler = ServerConnectionManagementHandler(
eventLoop: channel.eventLoop,
maxIdleTime: nil,
maxAge: nil,
maxGraceTime: nil,
keepaliveTime: nil,
keepaliveTimeout: nil,
allowKeepaliveWithoutCalls: false,
minPingIntervalWithoutCalls: .minutes(5),
requireALPN: false
)

return GRPCServerStreamHandler(
scheme: scheme,
acceptedEncodings: acceptedEncodings,
maxPayloadSize: maxPayloadSize,
methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(),
eventLoop: channel.eventLoop,
connectionManagementHandler: serverConnectionManagementHandler.syncView,
skipStateMachineAssertions: disableAssertions
)
}
Expand Down Expand Up @@ -974,28 +987,50 @@ final class GRPCServerStreamHandlerTests: XCTestCase {
}

struct ServerStreamHandlerTests {
private func makeServerStreamHandler(
struct ConnectionAndStreamHandlers {
let streamHandler: GRPCServerStreamHandler
let connectionHandler: ServerConnectionManagementHandler
}

private func makeServerConnectionAndStreamHandlers(
channel: any Channel,
scheme: Scheme = .http,
acceptedEncodings: CompressionAlgorithmSet = [],
maxPayloadSize: Int = .max,
descriptorPromise: EventLoopPromise<MethodDescriptor>? = nil,
disableAssertions: Bool = false
) -> GRPCServerStreamHandler {
return GRPCServerStreamHandler(
) -> ConnectionAndStreamHandlers {
let connectionManagementHandler = ServerConnectionManagementHandler(
eventLoop: channel.eventLoop,
maxIdleTime: nil,
maxAge: nil,
maxGraceTime: nil,
keepaliveTime: nil,
keepaliveTimeout: nil,
allowKeepaliveWithoutCalls: false,
minPingIntervalWithoutCalls: .minutes(5),
requireALPN: false
)
let streamHandler = GRPCServerStreamHandler(
scheme: scheme,
acceptedEncodings: acceptedEncodings,
maxPayloadSize: maxPayloadSize,
methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(),
eventLoop: channel.eventLoop,
connectionManagementHandler: connectionManagementHandler.syncView,
skipStateMachineAssertions: disableAssertions
)

return ConnectionAndStreamHandlers(
streamHandler: streamHandler,
connectionHandler: connectionManagementHandler
)
}

@Test("ChannelShouldQuiesceEvent is buffered and turns into RPC cancellation")
func shouldQuiesceEventIsBufferedBeforeHandleIsSet() async throws {
let channel = EmbeddedChannel()
let handler = self.makeServerStreamHandler(channel: channel)
let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
try channel.pipeline.syncOperations.addHandler(handler)
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())

Expand All @@ -1011,7 +1046,7 @@ struct ServerStreamHandlerTests {
@Test("ChannelShouldQuiesceEvent turns into RPC cancellation")
func shouldQuiesceEventTriggersCancellation() async throws {
let channel = EmbeddedChannel()
let handler = self.makeServerStreamHandler(channel: channel)
let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
try channel.pipeline.syncOperations.addHandler(handler)

await withServerContextRPCCancellationHandle { handle in
Expand All @@ -1028,7 +1063,7 @@ struct ServerStreamHandlerTests {
@Test("RST_STREAM turns into RPC cancellation")
func rstStreamTriggersCancellation() async throws {
let channel = EmbeddedChannel()
let handler = self.makeServerStreamHandler(channel: channel)
let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler
try channel.pipeline.syncOperations.addHandler(handler)

await withServerContextRPCCancellationHandle { handle in
Expand All @@ -1045,6 +1080,51 @@ struct ServerStreamHandlerTests {
_ = try? channel.finish()
}

@Test("Connection FrameStats are updated when writing headers or data frames")
func connectionFrameStatsAreUpdatedAccordingly() async throws {
let channel = EmbeddedChannel()
let handlers = self.makeServerConnectionAndStreamHandlers(channel: channel)
try channel.pipeline.syncOperations.addHandler(handlers.streamHandler)

// We have written nothing yet, so expect FrameStats/didWriteHeadersOrData to be false
#expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)

// FrameStats aren't affected by pings received
channel.pipeline.fireChannelRead(
NIOAny(HTTP2Frame.FramePayload.ping(.init(withInteger: 42), ack: false))
)
#expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)

// Now write back headers and make sure FrameStats are updated accordingly:
// To do that, we first need to receive client's initial metadata...
let clientInitialMetadata: HPACKHeaders = [
GRPCHTTP2Keys.path.rawValue: "/SomeService/SomeMethod",
GRPCHTTP2Keys.scheme.rawValue: "http",
GRPCHTTP2Keys.method.rawValue: "POST",
GRPCHTTP2Keys.contentType.rawValue: "application/grpc",
GRPCHTTP2Keys.te.rawValue: "trailers",
]
try channel.writeInbound(
HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata))
)

// Now we write back server's initial metadata...
let serverInitialMetadata = RPCResponsePart.metadata([:])
try channel.writeOutbound(serverInitialMetadata)

// And this should have updated the FrameStats
#expect(handlers.connectionHandler.frameStats.didWriteHeadersOrData)

// Manually reset the FrameStats to make sure that writing data also updates it correctly.
handlers.connectionHandler.frameStats.reset()
#expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData)
try channel.writeOutbound(RPCResponsePart.message([42]))
#expect(handlers.connectionHandler.frameStats.didWriteHeadersOrData)

// Clean up.
// Throwing is fine: the channel is closed abruptly, errors are expected.
_ = try? channel.finish()
}
}

extension EmbeddedChannel {
Expand Down

0 comments on commit 47e0be1

Please sign in to comment.