Skip to content

Commit

Permalink
Added WebSocketNonNegotiableExtensionBuilder (#5)
Browse files Browse the repository at this point in the history
* Added WebSocketNonNegotiableExtensionBuilder

* Fix test

* Make WebSocketNonNegotiableExtensionBuilder a struct

* nonNegociatedExtension

* Add comment for nonNegotiatedExtension
  • Loading branch information
adam-fowler authored Nov 19, 2024
1 parent e34cb2f commit 8f8dd9f
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 188 deletions.
14 changes: 7 additions & 7 deletions Sources/WSClient/WebSocketClientChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,13 @@ struct WebSocketClientChannel: ClientConnectionChannel {
let asyncChannel = try NIOAsyncChannel<WebSocketFrame, WebSocketFrame>(wrappingChannelSynchronously: channel)
// work out what extensions we should add based off the server response
let headerFields = HTTPFields(head.headers, splitCookie: false)
let serverExtensions = WebSocketExtensionHTTPParameters.parseHeaders(headerFields)
if serverExtensions.count > 0 {
let extensions = try configuration.extensions.buildClientExtensions(from: headerFields)
if extensions.count > 0 {
logger.debug(
"Enabled extensions",
metadata: ["hb.ws.extensions": .string(serverExtensions.map(\.name).joined(separator: ","))]
metadata: ["hb.ws.extensions": .string(extensions.map(\.name).joined(separator: ","))]
)
}
let extensions = try configuration.extensions.compactMap {
try $0.clientExtension(from: serverExtensions)
}
return UpgradeResult.websocket(asyncChannel, extensions)
}
}
Expand All @@ -71,7 +68,10 @@ struct WebSocketClientChannel: ClientConnectionChannel {
let additionalHeaders = HTTPHeaders(self.configuration.additionalHeaders)
headers.add(contentsOf: additionalHeaders)
// add websocket extensions to headers
headers.add(contentsOf: self.configuration.extensions.map { (name: "Sec-WebSocket-Extensions", value: $0.clientRequestHeader()) })
headers.add(contentsOf: self.configuration.extensions.compactMap {
let requestHeaders = $0.clientRequestHeader()
return requestHeaders != "" ? ("Sec-WebSocket-Extensions", requestHeaders) : nil
})

let requestHead = HTTPRequestHead(
version: .http1_1,
Expand Down
40 changes: 40 additions & 0 deletions Sources/WSCore/Extensions/WebSocketExtension.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Foundation
import HTTPTypes
import Logging
import NIOCore
import NIOWebSocket

/// Basic context implementation of ``WebSocketContext``.
public struct WebSocketExtensionContext: Sendable {
public let logger: Logger

init(logger: Logger) {
self.logger = logger
}
}

/// Protocol for WebSocket extension
public protocol WebSocketExtension: Sendable {
/// Extension name
var name: String { get }
/// Process frame received from websocket
func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame
/// Process frame about to be sent to websocket
func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame
/// shutdown extension
func shutdown() async
}
151 changes: 151 additions & 0 deletions Sources/WSCore/Extensions/WebSocketExtensionBuilder.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import HTTPTypes

/// Protocol for WebSocket extension builder
public protocol WebSocketExtensionBuilder: Sendable {
/// name of WebSocket extension name
static var name: String { get }
/// construct client request header
func clientRequestHeader() -> String
/// construct server response header based of client request
func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String?
/// construct server version of extension based of client request
func serverExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)?
/// construct client version of extension based of server response
func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)?
}

extension WebSocketExtensionBuilder {
/// construct server response header based of all client requests
public func serverResponseHeader(to requests: [WebSocketExtensionHTTPParameters]) -> String? {
for request in requests {
guard request.name == Self.name else { continue }
if let response = serverReponseHeader(to: request) {
return response
}
}
return nil
}

/// construct all server extensions based of all client requests
public func serverExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? {
for request in requests {
guard request.name == Self.name else { continue }
if let ext = try serverExtension(from: request) {
return ext
}
}
if let nonNegotiableExtensionBuilder = self as? any _WebSocketNonNegotiableExtensionBuilderProtocol {
return nonNegotiableExtensionBuilder.build()
}
return nil
}

/// construct all client extensions based of all server responses
public func clientExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? {
for request in requests {
guard request.name == Self.name else { continue }
if let ext = try clientExtension(from: request) {
return ext
}
}
if let nonNegotiableExtensionBuilder = self as? any _WebSocketNonNegotiableExtensionBuilderProtocol {
return nonNegotiableExtensionBuilder.build()
}
return nil
}
}

/// Protocol for w WebSocket extension that is applied without any negotiation with the other side
protocol _WebSocketNonNegotiableExtensionBuilderProtocol: WebSocketExtensionBuilder {
associatedtype Extension: WebSocketExtension
func build() -> Extension
}

/// A WebSocket extension that is applied without any negotiation with the other side
public struct WebSocketNonNegotiableExtensionBuilder<Extension: WebSocketExtension>: _WebSocketNonNegotiableExtensionBuilderProtocol {
public static var name: String { String(describing: type(of: Extension.self)) }

let _build: @Sendable () -> Extension

init(_ build: @escaping @Sendable () -> Extension) {
self._build = build
}

public func build() -> Extension {
self._build()
}
}

extension WebSocketNonNegotiableExtensionBuilder {
/// construct client request header
public func clientRequestHeader() -> String { "" }
/// construct server response header based of client request
public func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String? { nil }
/// construct server version of extension based of client request
public func serverExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? { self.build() }
/// construct client version of extension based of server response
public func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? { self.build() }
}

extension Array<any WebSocketExtensionBuilder> {
/// Build client extensions from response from WebSocket server
/// - Parameter responseHeaders: Server response headers
/// - Returns: Array of client extensions to enable
public func buildClientExtensions(from responseHeaders: HTTPFields) throws -> [any WebSocketExtension] {
let serverExtensions = WebSocketExtensionHTTPParameters.parseHeaders(responseHeaders)
return try self.compactMap {
try $0.clientExtension(from: serverExtensions)
}
}

/// Do the client/server WebSocket negotiation based off headers received from the client.
/// - Parameter requestHeaders: Client request headers
/// - Returns: Headers to pass back to client and array of server extensions to enable
public func serverExtensionNegotiation(requestHeaders: HTTPFields) throws -> (HTTPFields, [any WebSocketExtension]) {
var responseHeaders: HTTPFields = .init()
let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(requestHeaders)
let extensionResponseHeaders = self.compactMap { $0.serverResponseHeader(to: clientHeaders) }
responseHeaders.append(contentsOf: extensionResponseHeaders.map { .init(name: .secWebSocketExtensions, value: $0) })
let extensions = try self.compactMap {
try $0.serverExtension(from: clientHeaders)
}
return (responseHeaders, extensions)
}
}

/// Build WebSocket extension builder
public struct WebSocketExtensionFactory: Sendable {
public let build: @Sendable () -> any WebSocketExtensionBuilder

public init(_ build: @escaping @Sendable () -> any WebSocketExtensionBuilder) {
self.build = build
}

/// Extension to be applied without negotiation with the other side.
///
/// Most extensions involve some form of negotiation between the client and the server
/// to decide on whether they should be applied and with what parameters. This extension
/// builder is for the situation where no negotiation is needed or that negotiation has
/// already occurred.
///
/// - Parameter build: closure creating extension
/// - Returns: WebSocketExtensionFactory
public static func nonNegotiatedExtension(_ build: @escaping @Sendable () -> some WebSocketExtension) -> Self {
return .init {
WebSocketNonNegotiableExtensionBuilder(build)
}
}
}
90 changes: 90 additions & 0 deletions Sources/WSCore/Extensions/WebSocketExtensionHTTPParameters.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import HTTPTypes

/// Parsed parameters from `Sec-WebSocket-Extensions` header
public struct WebSocketExtensionHTTPParameters: Sendable, Equatable {
/// A single parameter
public enum Parameter: Sendable, Equatable {
// Parameter with a value
case value(String)
// Parameter with no value
case null

// Convert to optional
public var optional: String? {
switch self {
case .value(let string):
return .some(string)
case .null:
return .none
}
}

// Convert to integer
public var integer: Int? {
switch self {
case .value(let string):
return Int(string)
case .null:
return .none
}
}
}

public let parameters: [String: Parameter]
public let name: String

/// initialise WebSocket extension parameters from string
init?(from header: some StringProtocol) {
let split = header.split(separator: ";", omittingEmptySubsequences: true).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }[...]
if let name = split.first {
self.name = name
} else {
return nil
}
var index = split.index(after: split.startIndex)
var parameters: [String: Parameter] = [:]
while index != split.endIndex {
let keyValue = split[index].split(separator: "=", maxSplits: 1).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }
if let key = keyValue.first {
if keyValue.count > 1 {
parameters[key] = .value(keyValue[1])
} else {
parameters[key] = .null
}
}
index = split.index(after: index)
}
self.parameters = parameters
}

/// Parse all `Sec-WebSocket-Extensions` header values
/// - Parameters:
/// - headers: headers coming from other
/// - Returns: Array of extensions
public static func parseHeaders(_ headers: HTTPFields) -> [WebSocketExtensionHTTPParameters] {
let extHeaders = headers[values: .secWebSocketExtensions]
return extHeaders.compactMap { .init(from: $0) }
}
}

extension WebSocketExtensionHTTPParameters {
/// Initialiser used by tests
package init(_ name: String, parameters: [String: Parameter]) {
self.name = name
self.parameters = parameters
}
}
Loading

0 comments on commit 8f8dd9f

Please sign in to comment.