diff --git a/Sources/GRPCNIOTransportCore/Internal/NIOChannelPipeline+GRPC.swift b/Sources/GRPCNIOTransportCore/Internal/NIOChannelPipeline+GRPC.swift index b53dd08..89a95ae 100644 --- a/Sources/GRPCNIOTransportCore/Internal/NIOChannelPipeline+GRPC.swift +++ b/Sources/GRPCNIOTransportCore/Internal/NIOChannelPipeline+GRPC.swift @@ -71,6 +71,10 @@ extension ChannelPipeline.SynchronousOperations { var http2HandlerStreamConfiguration = NIOHTTP2Handler.StreamConfiguration() http2HandlerStreamConfiguration.targetWindowSize = clampedTargetWindowSize + let boundConnectionManagementHandler = NIOLoopBound( + serverConnectionHandler.syncView, + eventLoop: self.eventLoop + ) let streamMultiplexer = try self.configureAsyncHTTP2Pipeline( mode: .server, streamDelegate: serverConnectionHandler.http2StreamDelegate, @@ -86,7 +90,8 @@ extension ChannelPipeline.SynchronousOperations { acceptedEncodings: compressionConfig.enabledAlgorithms, maxPayloadSize: rpcConfig.maxRequestPayloadSize, methodDescriptorPromise: methodDescriptorPromise, - eventLoop: streamChannel.eventLoop + eventLoop: streamChannel.eventLoop, + connectionManagementHandler: boundConnectionManagementHandler.value ) try streamChannel.pipeline.syncOperations.addHandler(streamHandler) diff --git a/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift b/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift index afc10e3..78f94db 100644 --- a/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift +++ b/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift @@ -121,9 +121,9 @@ package final class ServerConnectionManagementHandler: ChannelDuplexHandler { } /// Stats about recently written frames. Used to determine whether to reset keep-alive state. - private var frameStats: FrameStats + package var frameStats: FrameStats - struct FrameStats { + package struct FrameStats { private(set) var didWriteHeadersOrData = false /// Mark that a HEADERS frame has been written. @@ -609,7 +609,13 @@ extension ServerConnectionManagementHandler { context.write(self.wrapOutboundOut(goAway), promise: nil) self.maybeFlush(context: context) - context.close(promise: nil) + + // We must delay the channel close after sending the GOAWAY packet by a tick to make sure it + // gets flushed and delivered to the client before the connection is closed. + let loopBound = NIOLoopBound(context, eventLoop: context.eventLoop) + context.eventLoop.execute { + loopBound.value.close(promise: nil) + } case .sendAck: () // ACKs are sent by NIO's HTTP/2 handler, don't double ack. diff --git a/Sources/GRPCNIOTransportCore/Server/GRPCServerStreamHandler.swift b/Sources/GRPCNIOTransportCore/Server/GRPCServerStreamHandler.swift index 22c7eef..4482a8b 100644 --- a/Sources/GRPCNIOTransportCore/Server/GRPCServerStreamHandler.swift +++ b/Sources/GRPCNIOTransportCore/Server/GRPCServerStreamHandler.swift @@ -42,6 +42,8 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan private var cancellationHandle: Optional + package let connectionManagementHandler: ServerConnectionManagementHandler.SyncView + // Existential errors unconditionally allocate, avoid this per-use allocation by doing it // statically. private static let handlerRemovedBeforeDescriptorResolved: any Error = RPCError( @@ -55,6 +57,7 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan maxPayloadSize: Int, methodDescriptorPromise: EventLoopPromise, eventLoop: any EventLoop, + connectionManagementHandler: ServerConnectionManagementHandler.SyncView, cancellationHandler: ServerContext.RPCCancellationHandle? = nil, skipStateMachineAssertions: Bool = false ) { @@ -66,6 +69,7 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan self.methodDescriptorPromise = methodDescriptorPromise self.cancellationHandle = cancellationHandler self.eventLoop = eventLoop + self.connectionManagementHandler = connectionManagementHandler } package func setCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) { @@ -136,13 +140,16 @@ extension GRPCServerStreamHandler { switch self.stateMachine.nextInboundMessage() { case .receiveMessage(let message): context.fireChannelRead(self.wrapInboundOut(.message(message))) + case .awaitMoreMessages: break loop + case .noMoreMessages: context.fireUserInboundEventTriggered(ChannelEvent.inputClosed) break loop } } + case .doNothing: () } @@ -261,6 +268,7 @@ extension GRPCServerStreamHandler { self.flushPending = true let headers = try self.stateMachine.send(metadata: metadata) context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: promise) + self.connectionManagementHandler.wroteHeadersFrame() } catch let invalidState { let error = RPCError(invalidState) promise?.fail(error) @@ -270,6 +278,7 @@ extension GRPCServerStreamHandler { case .message(let message): do { try self.stateMachine.send(message: message, promise: promise) + self.connectionManagementHandler.wroteDataFrame() } catch let invalidState { let error = RPCError(invalidState) promise?.fail(error) diff --git a/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/ConnectionTest.swift b/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/ConnectionTest.swift index c64a2f6..9bf2537 100644 --- a/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/ConnectionTest.swift +++ b/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/ConnectionTest.swift @@ -114,12 +114,24 @@ extension ConnectionTest { let h2 = NIOHTTP2Handler(mode: .server) let mux = HTTP2StreamMultiplexer(mode: .server, channel: channel) { stream in let sync = stream.pipeline.syncOperations + let connectionManagementHandler = ServerConnectionManagementHandler( + eventLoop: stream.eventLoop, + maxIdleTime: nil, + maxAge: nil, + maxGraceTime: nil, + keepaliveTime: nil, + keepaliveTimeout: nil, + allowKeepaliveWithoutCalls: false, + minPingIntervalWithoutCalls: .minutes(5), + requireALPN: false + ) let handler = GRPCServerStreamHandler( scheme: .http, acceptedEncodings: .none, maxPayloadSize: .max, methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self), - eventLoop: stream.eventLoop + eventLoop: stream.eventLoop, + connectionManagementHandler: connectionManagementHandler.syncView ) return stream.eventLoop.makeCompletedFuture { diff --git a/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/TestServer.swift b/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/TestServer.swift index 01ec6ad..452320b 100644 --- a/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/TestServer.swift +++ b/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/TestServer.swift @@ -70,12 +70,24 @@ final class TestServer: Sendable { let sync = channel.pipeline.syncOperations let multiplexer = try sync.configureAsyncHTTP2Pipeline(mode: .server) { stream in stream.eventLoop.makeCompletedFuture { + let connectionManagementHandler = ServerConnectionManagementHandler( + eventLoop: stream.eventLoop, + maxIdleTime: nil, + maxAge: nil, + maxGraceTime: nil, + keepaliveTime: nil, + keepaliveTimeout: nil, + allowKeepaliveWithoutCalls: false, + minPingIntervalWithoutCalls: .minutes(5), + requireALPN: false + ) let handler = GRPCServerStreamHandler( scheme: .http, acceptedEncodings: .all, maxPayloadSize: .max, methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self), - eventLoop: stream.eventLoop + eventLoop: stream.eventLoop, + connectionManagementHandler: connectionManagementHandler.syncView ) try stream.pipeline.syncOperations.addHandlers(handler) diff --git a/Tests/GRPCNIOTransportCoreTests/Server/GRPCServerStreamHandlerTests.swift b/Tests/GRPCNIOTransportCoreTests/Server/GRPCServerStreamHandlerTests.swift index 6319e80..74bdf62 100644 --- a/Tests/GRPCNIOTransportCoreTests/Server/GRPCServerStreamHandlerTests.swift +++ b/Tests/GRPCNIOTransportCoreTests/Server/GRPCServerStreamHandlerTests.swift @@ -33,12 +33,25 @@ final class GRPCServerStreamHandlerTests: XCTestCase { descriptorPromise: EventLoopPromise? = nil, disableAssertions: Bool = false ) -> GRPCServerStreamHandler { + let serverConnectionManagementHandler = ServerConnectionManagementHandler( + eventLoop: channel.eventLoop, + maxIdleTime: nil, + maxAge: nil, + maxGraceTime: nil, + keepaliveTime: nil, + keepaliveTimeout: nil, + allowKeepaliveWithoutCalls: false, + minPingIntervalWithoutCalls: .minutes(5), + requireALPN: false + ) + return GRPCServerStreamHandler( scheme: scheme, acceptedEncodings: acceptedEncodings, maxPayloadSize: maxPayloadSize, methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(), eventLoop: channel.eventLoop, + connectionManagementHandler: serverConnectionManagementHandler.syncView, skipStateMachineAssertions: disableAssertions ) } @@ -974,28 +987,50 @@ final class GRPCServerStreamHandlerTests: XCTestCase { } struct ServerStreamHandlerTests { - private func makeServerStreamHandler( + struct ConnectionAndStreamHandlers { + let streamHandler: GRPCServerStreamHandler + let connectionHandler: ServerConnectionManagementHandler + } + + private func makeServerConnectionAndStreamHandlers( channel: any Channel, scheme: Scheme = .http, acceptedEncodings: CompressionAlgorithmSet = [], maxPayloadSize: Int = .max, descriptorPromise: EventLoopPromise? = nil, disableAssertions: Bool = false - ) -> GRPCServerStreamHandler { - return GRPCServerStreamHandler( + ) -> ConnectionAndStreamHandlers { + let connectionManagementHandler = ServerConnectionManagementHandler( + eventLoop: channel.eventLoop, + maxIdleTime: nil, + maxAge: nil, + maxGraceTime: nil, + keepaliveTime: nil, + keepaliveTimeout: nil, + allowKeepaliveWithoutCalls: false, + minPingIntervalWithoutCalls: .minutes(5), + requireALPN: false + ) + let streamHandler = GRPCServerStreamHandler( scheme: scheme, acceptedEncodings: acceptedEncodings, maxPayloadSize: maxPayloadSize, methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(), eventLoop: channel.eventLoop, + connectionManagementHandler: connectionManagementHandler.syncView, skipStateMachineAssertions: disableAssertions ) + + return ConnectionAndStreamHandlers( + streamHandler: streamHandler, + connectionHandler: connectionManagementHandler + ) } @Test("ChannelShouldQuiesceEvent is buffered and turns into RPC cancellation") func shouldQuiesceEventIsBufferedBeforeHandleIsSet() async throws { let channel = EmbeddedChannel() - let handler = self.makeServerStreamHandler(channel: channel) + let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler try channel.pipeline.syncOperations.addHandler(handler) channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) @@ -1011,7 +1046,7 @@ struct ServerStreamHandlerTests { @Test("ChannelShouldQuiesceEvent turns into RPC cancellation") func shouldQuiesceEventTriggersCancellation() async throws { let channel = EmbeddedChannel() - let handler = self.makeServerStreamHandler(channel: channel) + let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler try channel.pipeline.syncOperations.addHandler(handler) await withServerContextRPCCancellationHandle { handle in @@ -1028,7 +1063,7 @@ struct ServerStreamHandlerTests { @Test("RST_STREAM turns into RPC cancellation") func rstStreamTriggersCancellation() async throws { let channel = EmbeddedChannel() - let handler = self.makeServerStreamHandler(channel: channel) + let handler = self.makeServerConnectionAndStreamHandlers(channel: channel).streamHandler try channel.pipeline.syncOperations.addHandler(handler) await withServerContextRPCCancellationHandle { handle in @@ -1045,6 +1080,51 @@ struct ServerStreamHandlerTests { _ = try? channel.finish() } + @Test("Connection FrameStats are updated when writing headers or data frames") + func connectionFrameStatsAreUpdatedAccordingly() async throws { + let channel = EmbeddedChannel() + let handlers = self.makeServerConnectionAndStreamHandlers(channel: channel) + try channel.pipeline.syncOperations.addHandler(handlers.streamHandler) + + // We have written nothing yet, so expect FrameStats/didWriteHeadersOrData to be false + #expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData) + + // FrameStats aren't affected by pings received + channel.pipeline.fireChannelRead( + NIOAny(HTTP2Frame.FramePayload.ping(.init(withInteger: 42), ack: false)) + ) + #expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData) + + // Now write back headers and make sure FrameStats are updated accordingly: + // To do that, we first need to receive client's initial metadata... + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "/SomeService/SomeMethod", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + + // Now we write back server's initial metadata... + let serverInitialMetadata = RPCResponsePart.metadata([:]) + try channel.writeOutbound(serverInitialMetadata) + + // And this should have updated the FrameStats + #expect(handlers.connectionHandler.frameStats.didWriteHeadersOrData) + + // Manually reset the FrameStats to make sure that writing data also updates it correctly. + handlers.connectionHandler.frameStats.reset() + #expect(!handlers.connectionHandler.frameStats.didWriteHeadersOrData) + try channel.writeOutbound(RPCResponsePart.message([42])) + #expect(handlers.connectionHandler.frameStats.didWriteHeadersOrData) + + // Clean up. + // Throwing is fine: the channel is closed abruptly, errors are expected. + _ = try? channel.finish() + } } extension EmbeddedChannel {