From 8f8dd9fd4f8f0f31874f799748761f8ea7ee98d4 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Tue, 19 Nov 2024 11:27:37 +0000 Subject: [PATCH] Added WebSocketNonNegotiableExtensionBuilder (#5) * Added WebSocketNonNegotiableExtensionBuilder * Fix test * Make WebSocketNonNegotiableExtensionBuilder a struct * nonNegociatedExtension * Add comment for nonNegotiatedExtension --- Sources/WSClient/WebSocketClientChannel.swift | 14 +- .../Extensions/WebSocketExtension.swift | 40 ++++ .../WebSocketExtensionBuilder.swift | 151 +++++++++++++++ .../WebSocketExtensionHTTPParameters.swift | 90 +++++++++ Sources/WSCore/WebSocketExtension.swift | 173 ------------------ .../WebSocketExtensionNegotiationTests.swift | 72 +++++++- 6 files changed, 352 insertions(+), 188 deletions(-) create mode 100644 Sources/WSCore/Extensions/WebSocketExtension.swift create mode 100644 Sources/WSCore/Extensions/WebSocketExtensionBuilder.swift create mode 100644 Sources/WSCore/Extensions/WebSocketExtensionHTTPParameters.swift delete mode 100644 Sources/WSCore/WebSocketExtension.swift diff --git a/Sources/WSClient/WebSocketClientChannel.swift b/Sources/WSClient/WebSocketClientChannel.swift index 9fb8623..c53d700 100644 --- a/Sources/WSClient/WebSocketClientChannel.swift +++ b/Sources/WSClient/WebSocketClientChannel.swift @@ -50,16 +50,13 @@ struct WebSocketClientChannel: ClientConnectionChannel { let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) // work out what extensions we should add based off the server response let headerFields = HTTPFields(head.headers, splitCookie: false) - let serverExtensions = WebSocketExtensionHTTPParameters.parseHeaders(headerFields) - if serverExtensions.count > 0 { + let extensions = try configuration.extensions.buildClientExtensions(from: headerFields) + if extensions.count > 0 { logger.debug( "Enabled extensions", - metadata: ["hb.ws.extensions": .string(serverExtensions.map(\.name).joined(separator: ","))] + metadata: ["hb.ws.extensions": .string(extensions.map(\.name).joined(separator: ","))] ) } - let extensions = try configuration.extensions.compactMap { - try $0.clientExtension(from: serverExtensions) - } return UpgradeResult.websocket(asyncChannel, extensions) } } @@ -71,7 +68,10 @@ struct WebSocketClientChannel: ClientConnectionChannel { let additionalHeaders = HTTPHeaders(self.configuration.additionalHeaders) headers.add(contentsOf: additionalHeaders) // add websocket extensions to headers - headers.add(contentsOf: self.configuration.extensions.map { (name: "Sec-WebSocket-Extensions", value: $0.clientRequestHeader()) }) + headers.add(contentsOf: self.configuration.extensions.compactMap { + let requestHeaders = $0.clientRequestHeader() + return requestHeaders != "" ? ("Sec-WebSocket-Extensions", requestHeaders) : nil + }) let requestHead = HTTPRequestHead( version: .http1_1, diff --git a/Sources/WSCore/Extensions/WebSocketExtension.swift b/Sources/WSCore/Extensions/WebSocketExtension.swift new file mode 100644 index 0000000..c333df0 --- /dev/null +++ b/Sources/WSCore/Extensions/WebSocketExtension.swift @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Foundation +import HTTPTypes +import Logging +import NIOCore +import NIOWebSocket + +/// Basic context implementation of ``WebSocketContext``. +public struct WebSocketExtensionContext: Sendable { + public let logger: Logger + + init(logger: Logger) { + self.logger = logger + } +} + +/// Protocol for WebSocket extension +public protocol WebSocketExtension: Sendable { + /// Extension name + var name: String { get } + /// Process frame received from websocket + func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame + /// Process frame about to be sent to websocket + func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame + /// shutdown extension + func shutdown() async +} diff --git a/Sources/WSCore/Extensions/WebSocketExtensionBuilder.swift b/Sources/WSCore/Extensions/WebSocketExtensionBuilder.swift new file mode 100644 index 0000000..1cb6c9e --- /dev/null +++ b/Sources/WSCore/Extensions/WebSocketExtensionBuilder.swift @@ -0,0 +1,151 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import HTTPTypes + +/// Protocol for WebSocket extension builder +public protocol WebSocketExtensionBuilder: Sendable { + /// name of WebSocket extension name + static var name: String { get } + /// construct client request header + func clientRequestHeader() -> String + /// construct server response header based of client request + func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String? + /// construct server version of extension based of client request + func serverExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? + /// construct client version of extension based of server response + func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? +} + +extension WebSocketExtensionBuilder { + /// construct server response header based of all client requests + public func serverResponseHeader(to requests: [WebSocketExtensionHTTPParameters]) -> String? { + for request in requests { + guard request.name == Self.name else { continue } + if let response = serverReponseHeader(to: request) { + return response + } + } + return nil + } + + /// construct all server extensions based of all client requests + public func serverExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? { + for request in requests { + guard request.name == Self.name else { continue } + if let ext = try serverExtension(from: request) { + return ext + } + } + if let nonNegotiableExtensionBuilder = self as? any _WebSocketNonNegotiableExtensionBuilderProtocol { + return nonNegotiableExtensionBuilder.build() + } + return nil + } + + /// construct all client extensions based of all server responses + public func clientExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? { + for request in requests { + guard request.name == Self.name else { continue } + if let ext = try clientExtension(from: request) { + return ext + } + } + if let nonNegotiableExtensionBuilder = self as? any _WebSocketNonNegotiableExtensionBuilderProtocol { + return nonNegotiableExtensionBuilder.build() + } + return nil + } +} + +/// Protocol for w WebSocket extension that is applied without any negotiation with the other side +protocol _WebSocketNonNegotiableExtensionBuilderProtocol: WebSocketExtensionBuilder { + associatedtype Extension: WebSocketExtension + func build() -> Extension +} + +/// A WebSocket extension that is applied without any negotiation with the other side +public struct WebSocketNonNegotiableExtensionBuilder: _WebSocketNonNegotiableExtensionBuilderProtocol { + public static var name: String { String(describing: type(of: Extension.self)) } + + let _build: @Sendable () -> Extension + + init(_ build: @escaping @Sendable () -> Extension) { + self._build = build + } + + public func build() -> Extension { + self._build() + } +} + +extension WebSocketNonNegotiableExtensionBuilder { + /// construct client request header + public func clientRequestHeader() -> String { "" } + /// construct server response header based of client request + public func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String? { nil } + /// construct server version of extension based of client request + public func serverExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? { self.build() } + /// construct client version of extension based of server response + public func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? { self.build() } +} + +extension Array { + /// Build client extensions from response from WebSocket server + /// - Parameter responseHeaders: Server response headers + /// - Returns: Array of client extensions to enable + public func buildClientExtensions(from responseHeaders: HTTPFields) throws -> [any WebSocketExtension] { + let serverExtensions = WebSocketExtensionHTTPParameters.parseHeaders(responseHeaders) + return try self.compactMap { + try $0.clientExtension(from: serverExtensions) + } + } + + /// Do the client/server WebSocket negotiation based off headers received from the client. + /// - Parameter requestHeaders: Client request headers + /// - Returns: Headers to pass back to client and array of server extensions to enable + public func serverExtensionNegotiation(requestHeaders: HTTPFields) throws -> (HTTPFields, [any WebSocketExtension]) { + var responseHeaders: HTTPFields = .init() + let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(requestHeaders) + let extensionResponseHeaders = self.compactMap { $0.serverResponseHeader(to: clientHeaders) } + responseHeaders.append(contentsOf: extensionResponseHeaders.map { .init(name: .secWebSocketExtensions, value: $0) }) + let extensions = try self.compactMap { + try $0.serverExtension(from: clientHeaders) + } + return (responseHeaders, extensions) + } +} + +/// Build WebSocket extension builder +public struct WebSocketExtensionFactory: Sendable { + public let build: @Sendable () -> any WebSocketExtensionBuilder + + public init(_ build: @escaping @Sendable () -> any WebSocketExtensionBuilder) { + self.build = build + } + + /// Extension to be applied without negotiation with the other side. + /// + /// Most extensions involve some form of negotiation between the client and the server + /// to decide on whether they should be applied and with what parameters. This extension + /// builder is for the situation where no negotiation is needed or that negotiation has + /// already occurred. + /// + /// - Parameter build: closure creating extension + /// - Returns: WebSocketExtensionFactory + public static func nonNegotiatedExtension(_ build: @escaping @Sendable () -> some WebSocketExtension) -> Self { + return .init { + WebSocketNonNegotiableExtensionBuilder(build) + } + } +} diff --git a/Sources/WSCore/Extensions/WebSocketExtensionHTTPParameters.swift b/Sources/WSCore/Extensions/WebSocketExtensionHTTPParameters.swift new file mode 100644 index 0000000..1305b7e --- /dev/null +++ b/Sources/WSCore/Extensions/WebSocketExtensionHTTPParameters.swift @@ -0,0 +1,90 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import HTTPTypes + +/// Parsed parameters from `Sec-WebSocket-Extensions` header +public struct WebSocketExtensionHTTPParameters: Sendable, Equatable { + /// A single parameter + public enum Parameter: Sendable, Equatable { + // Parameter with a value + case value(String) + // Parameter with no value + case null + + // Convert to optional + public var optional: String? { + switch self { + case .value(let string): + return .some(string) + case .null: + return .none + } + } + + // Convert to integer + public var integer: Int? { + switch self { + case .value(let string): + return Int(string) + case .null: + return .none + } + } + } + + public let parameters: [String: Parameter] + public let name: String + + /// initialise WebSocket extension parameters from string + init?(from header: some StringProtocol) { + let split = header.split(separator: ";", omittingEmptySubsequences: true).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }[...] + if let name = split.first { + self.name = name + } else { + return nil + } + var index = split.index(after: split.startIndex) + var parameters: [String: Parameter] = [:] + while index != split.endIndex { + let keyValue = split[index].split(separator: "=", maxSplits: 1).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) } + if let key = keyValue.first { + if keyValue.count > 1 { + parameters[key] = .value(keyValue[1]) + } else { + parameters[key] = .null + } + } + index = split.index(after: index) + } + self.parameters = parameters + } + + /// Parse all `Sec-WebSocket-Extensions` header values + /// - Parameters: + /// - headers: headers coming from other + /// - Returns: Array of extensions + public static func parseHeaders(_ headers: HTTPFields) -> [WebSocketExtensionHTTPParameters] { + let extHeaders = headers[values: .secWebSocketExtensions] + return extHeaders.compactMap { .init(from: $0) } + } +} + +extension WebSocketExtensionHTTPParameters { + /// Initialiser used by tests + package init(_ name: String, parameters: [String: Parameter]) { + self.name = name + self.parameters = parameters + } +} diff --git a/Sources/WSCore/WebSocketExtension.swift b/Sources/WSCore/WebSocketExtension.swift deleted file mode 100644 index 248e07e..0000000 --- a/Sources/WSCore/WebSocketExtension.swift +++ /dev/null @@ -1,173 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Hummingbird server framework project -// -// Copyright (c) 2023 the Hummingbird authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import Foundation -import HTTPTypes -import Logging -import NIOCore -import NIOWebSocket - -/// Basic context implementation of ``WebSocketContext``. -public struct WebSocketExtensionContext: Sendable { - public let logger: Logger - - init(logger: Logger) { - self.logger = logger - } -} - -/// Protocol for WebSocket extension -public protocol WebSocketExtension: Sendable { - /// Extension name - var name: String { get } - /// Process frame received from websocket - func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame - /// Process frame about to be sent to websocket - func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame - /// shutdown extension - func shutdown() async -} - -/// Protocol for WebSocket extension builder -public protocol WebSocketExtensionBuilder: Sendable { - /// name of WebSocket extension name - static var name: String { get } - /// construct client request header - func clientRequestHeader() -> String - /// construct server response header based of client request - func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String? - /// construct server version of extension based of client request - func serverExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? - /// construct client version of extension based of server response - func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? -} - -extension WebSocketExtensionBuilder { - /// construct server response header based of all client requests - public func serverResponseHeader(to requests: [WebSocketExtensionHTTPParameters]) -> String? { - for request in requests { - guard request.name == Self.name else { continue } - if let response = serverReponseHeader(to: request) { - return response - } - } - return nil - } - - /// construct all server extensions based of all client requests - public func serverExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? { - for request in requests { - guard request.name == Self.name else { continue } - if let ext = try serverExtension(from: request) { - return ext - } - } - return nil - } - - /// construct all client extensions based of all server responses - public func clientExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? { - for request in requests { - guard request.name == Self.name else { continue } - if let ext = try clientExtension(from: request) { - return ext - } - } - return nil - } -} - -/// Build WebSocket extension builder -public struct WebSocketExtensionFactory: Sendable { - public let build: @Sendable () -> any WebSocketExtensionBuilder - - public init(_ build: @escaping @Sendable () -> any WebSocketExtensionBuilder) { - self.build = build - } -} - -/// Parsed parameters from `Sec-WebSocket-Extensions` header -public struct WebSocketExtensionHTTPParameters: Sendable, Equatable { - /// A single parameter - public enum Parameter: Sendable, Equatable { - // Parameter with a value - case value(String) - // Parameter with no value - case null - - // Convert to optional - public var optional: String? { - switch self { - case .value(let string): - return .some(string) - case .null: - return .none - } - } - - // Convert to integer - public var integer: Int? { - switch self { - case .value(let string): - return Int(string) - case .null: - return .none - } - } - } - - public let parameters: [String: Parameter] - public let name: String - - /// initialise WebSocket extension parameters from string - init?(from header: some StringProtocol) { - let split = header.split(separator: ";", omittingEmptySubsequences: true).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }[...] - if let name = split.first { - self.name = name - } else { - return nil - } - var index = split.index(after: split.startIndex) - var parameters: [String: Parameter] = [:] - while index != split.endIndex { - let keyValue = split[index].split(separator: "=", maxSplits: 1).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) } - if let key = keyValue.first { - if keyValue.count > 1 { - parameters[key] = .value(keyValue[1]) - } else { - parameters[key] = .null - } - } - index = split.index(after: index) - } - self.parameters = parameters - } - - /// Parse all `Sec-WebSocket-Extensions` header values - /// - Parameters: - /// - headers: headers coming from other - /// - Returns: Array of extensions - public static func parseHeaders(_ headers: HTTPFields) -> [WebSocketExtensionHTTPParameters] { - let extHeaders = headers[values: .secWebSocketExtensions] - return extHeaders.compactMap { .init(from: $0) } - } -} - -extension WebSocketExtensionHTTPParameters { - /// Initialiser used by tests - package init(_ name: String, parameters: [String: Parameter]) { - self.name = name - self.parameters = parameters - } -} diff --git a/Tests/WebSocketTests/WebSocketExtensionNegotiationTests.swift b/Tests/WebSocketTests/WebSocketExtensionNegotiationTests.swift index 053c923..01071c1 100644 --- a/Tests/WebSocketTests/WebSocketExtensionNegotiationTests.swift +++ b/Tests/WebSocketTests/WebSocketExtensionNegotiationTests.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import HTTPTypes +import NIOWebSocket @testable import WSCompression @testable import WSCore import XCTest @@ -63,16 +64,71 @@ final class WebSocketExtensionNegotiationTests: XCTestCase { ) } - func testUnregonisedExtensionServerResponse() { - let requestHeaders: [WebSocketExtensionHTTPParameters] = [ - .init("permessage-foo", parameters: ["bar": .value("baz")]), - .init("permessage-deflate", parameters: ["client_max_window_bits": .value("10")]), - ] - let ext = PerMessageDeflateExtensionBuilder() - let serverResponse = ext.serverResponseHeader(to: requestHeaders) + func testUnregonisedExtensionServerResponse() throws { + let serverExtensions: [WebSocketExtensionBuilder] = [PerMessageDeflateExtensionBuilder()] + let (headers, extensions) = try serverExtensions.serverExtensionNegotiation( + requestHeaders: [ + .secWebSocketExtensions: "permessage-foo;bar=baz", + .secWebSocketExtensions: "permessage-deflate;client_max_window_bits=10", + ] + ) XCTAssertEqual( - serverResponse, + headers[.secWebSocketExtensions], "permessage-deflate;client_max_window_bits=10" ) + XCTAssertEqual(extensions.count, 1) + let firstExtension = try XCTUnwrap(extensions.first) + XCTAssert(firstExtension is PerMessageDeflateExtension) + + let requestExtensions = try serverExtensions.buildClientExtensions(from: headers) + XCTAssertEqual(requestExtensions.count, 1) + XCTAssert(requestExtensions[0] is PerMessageDeflateExtension) + } + + func testNonNegotiableClientExtension() throws { + struct MyExtension: WebSocketExtension { + var name = "my-extension" + + func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { + return frame + } + + func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { + return frame + } + + func shutdown() async {} + } + let clientExtensionBuilders: [WebSocketExtensionBuilder] = [WebSocketExtensionFactory.nonNegotiatedExtension { + MyExtension() + }.build()] + let clientExtensions = try clientExtensionBuilders.buildClientExtensions(from: [:]) + XCTAssertEqual(clientExtensions.count, 1) + let myExtension = try XCTUnwrap(clientExtensions.first) + XCTAssert(myExtension is MyExtension) + } + + func testNonNegotiableServerExtension() throws { + struct MyExtension: WebSocketExtension { + var name = "my-extension" + + func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { + return frame + } + + func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { + return frame + } + + func shutdown() async {} + } + let serverExtensionBuilders: [WebSocketExtensionBuilder] = [WebSocketNonNegotiableExtensionBuilder { MyExtension() }] + let (headers, serverExtensions) = try serverExtensionBuilders.serverExtensionNegotiation( + requestHeaders: [:] + ) + XCTAssertEqual(headers.count, 0) + XCTAssertEqual(serverExtensions.count, 1) + let myExtension = try XCTUnwrap(serverExtensions.first) + XCTAssert(myExtension is MyExtension) } }