-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added WebSocketNonNegotiableExtensionBuilder (#5)
* Added WebSocketNonNegotiableExtensionBuilder * Fix test * Make WebSocketNonNegotiableExtensionBuilder a struct * nonNegociatedExtension * Add comment for nonNegotiatedExtension
- Loading branch information
1 parent
e34cb2f
commit 8f8dd9f
Showing
6 changed files
with
352 additions
and
188 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This source file is part of the Hummingbird server framework project | ||
// | ||
// Copyright (c) 2023-2024 the Hummingbird authors | ||
// Licensed under Apache License v2.0 | ||
// | ||
// See LICENSE.txt for license information | ||
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
import Foundation | ||
import HTTPTypes | ||
import Logging | ||
import NIOCore | ||
import NIOWebSocket | ||
|
||
/// Basic context implementation of ``WebSocketContext``. | ||
public struct WebSocketExtensionContext: Sendable { | ||
public let logger: Logger | ||
|
||
init(logger: Logger) { | ||
self.logger = logger | ||
} | ||
} | ||
|
||
/// Protocol for WebSocket extension | ||
public protocol WebSocketExtension: Sendable { | ||
/// Extension name | ||
var name: String { get } | ||
/// Process frame received from websocket | ||
func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame | ||
/// Process frame about to be sent to websocket | ||
func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame | ||
/// shutdown extension | ||
func shutdown() async | ||
} |
151 changes: 151 additions & 0 deletions
151
Sources/WSCore/Extensions/WebSocketExtensionBuilder.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This source file is part of the Hummingbird server framework project | ||
// | ||
// Copyright (c) 2023-2024 the Hummingbird authors | ||
// Licensed under Apache License v2.0 | ||
// | ||
// See LICENSE.txt for license information | ||
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
//===----------------------------------------------------------------------===// | ||
import HTTPTypes | ||
|
||
/// Protocol for WebSocket extension builder | ||
public protocol WebSocketExtensionBuilder: Sendable { | ||
/// name of WebSocket extension name | ||
static var name: String { get } | ||
/// construct client request header | ||
func clientRequestHeader() -> String | ||
/// construct server response header based of client request | ||
func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String? | ||
/// construct server version of extension based of client request | ||
func serverExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? | ||
/// construct client version of extension based of server response | ||
func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? | ||
} | ||
|
||
extension WebSocketExtensionBuilder { | ||
/// construct server response header based of all client requests | ||
public func serverResponseHeader(to requests: [WebSocketExtensionHTTPParameters]) -> String? { | ||
for request in requests { | ||
guard request.name == Self.name else { continue } | ||
if let response = serverReponseHeader(to: request) { | ||
return response | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
/// construct all server extensions based of all client requests | ||
public func serverExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? { | ||
for request in requests { | ||
guard request.name == Self.name else { continue } | ||
if let ext = try serverExtension(from: request) { | ||
return ext | ||
} | ||
} | ||
if let nonNegotiableExtensionBuilder = self as? any _WebSocketNonNegotiableExtensionBuilderProtocol { | ||
return nonNegotiableExtensionBuilder.build() | ||
} | ||
return nil | ||
} | ||
|
||
/// construct all client extensions based of all server responses | ||
public func clientExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? { | ||
for request in requests { | ||
guard request.name == Self.name else { continue } | ||
if let ext = try clientExtension(from: request) { | ||
return ext | ||
} | ||
} | ||
if let nonNegotiableExtensionBuilder = self as? any _WebSocketNonNegotiableExtensionBuilderProtocol { | ||
return nonNegotiableExtensionBuilder.build() | ||
} | ||
return nil | ||
} | ||
} | ||
|
||
/// Protocol for w WebSocket extension that is applied without any negotiation with the other side | ||
protocol _WebSocketNonNegotiableExtensionBuilderProtocol: WebSocketExtensionBuilder { | ||
associatedtype Extension: WebSocketExtension | ||
func build() -> Extension | ||
} | ||
|
||
/// A WebSocket extension that is applied without any negotiation with the other side | ||
public struct WebSocketNonNegotiableExtensionBuilder<Extension: WebSocketExtension>: _WebSocketNonNegotiableExtensionBuilderProtocol { | ||
public static var name: String { String(describing: type(of: Extension.self)) } | ||
|
||
let _build: @Sendable () -> Extension | ||
|
||
init(_ build: @escaping @Sendable () -> Extension) { | ||
self._build = build | ||
} | ||
|
||
public func build() -> Extension { | ||
self._build() | ||
} | ||
} | ||
|
||
extension WebSocketNonNegotiableExtensionBuilder { | ||
/// construct client request header | ||
public func clientRequestHeader() -> String { "" } | ||
/// construct server response header based of client request | ||
public func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String? { nil } | ||
/// construct server version of extension based of client request | ||
public func serverExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? { self.build() } | ||
/// construct client version of extension based of server response | ||
public func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? { self.build() } | ||
} | ||
|
||
extension Array<any WebSocketExtensionBuilder> { | ||
/// Build client extensions from response from WebSocket server | ||
/// - Parameter responseHeaders: Server response headers | ||
/// - Returns: Array of client extensions to enable | ||
public func buildClientExtensions(from responseHeaders: HTTPFields) throws -> [any WebSocketExtension] { | ||
let serverExtensions = WebSocketExtensionHTTPParameters.parseHeaders(responseHeaders) | ||
return try self.compactMap { | ||
try $0.clientExtension(from: serverExtensions) | ||
} | ||
} | ||
|
||
/// Do the client/server WebSocket negotiation based off headers received from the client. | ||
/// - Parameter requestHeaders: Client request headers | ||
/// - Returns: Headers to pass back to client and array of server extensions to enable | ||
public func serverExtensionNegotiation(requestHeaders: HTTPFields) throws -> (HTTPFields, [any WebSocketExtension]) { | ||
var responseHeaders: HTTPFields = .init() | ||
let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(requestHeaders) | ||
let extensionResponseHeaders = self.compactMap { $0.serverResponseHeader(to: clientHeaders) } | ||
responseHeaders.append(contentsOf: extensionResponseHeaders.map { .init(name: .secWebSocketExtensions, value: $0) }) | ||
let extensions = try self.compactMap { | ||
try $0.serverExtension(from: clientHeaders) | ||
} | ||
return (responseHeaders, extensions) | ||
} | ||
} | ||
|
||
/// Build WebSocket extension builder | ||
public struct WebSocketExtensionFactory: Sendable { | ||
public let build: @Sendable () -> any WebSocketExtensionBuilder | ||
|
||
public init(_ build: @escaping @Sendable () -> any WebSocketExtensionBuilder) { | ||
self.build = build | ||
} | ||
|
||
/// Extension to be applied without negotiation with the other side. | ||
/// | ||
/// Most extensions involve some form of negotiation between the client and the server | ||
/// to decide on whether they should be applied and with what parameters. This extension | ||
/// builder is for the situation where no negotiation is needed or that negotiation has | ||
/// already occurred. | ||
/// | ||
/// - Parameter build: closure creating extension | ||
/// - Returns: WebSocketExtensionFactory | ||
public static func nonNegotiatedExtension(_ build: @escaping @Sendable () -> some WebSocketExtension) -> Self { | ||
return .init { | ||
WebSocketNonNegotiableExtensionBuilder(build) | ||
} | ||
} | ||
} |
90 changes: 90 additions & 0 deletions
90
Sources/WSCore/Extensions/WebSocketExtensionHTTPParameters.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This source file is part of the Hummingbird server framework project | ||
// | ||
// Copyright (c) 2023-2024 the Hummingbird authors | ||
// Licensed under Apache License v2.0 | ||
// | ||
// See LICENSE.txt for license information | ||
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
import HTTPTypes | ||
|
||
/// Parsed parameters from `Sec-WebSocket-Extensions` header | ||
public struct WebSocketExtensionHTTPParameters: Sendable, Equatable { | ||
/// A single parameter | ||
public enum Parameter: Sendable, Equatable { | ||
// Parameter with a value | ||
case value(String) | ||
// Parameter with no value | ||
case null | ||
|
||
// Convert to optional | ||
public var optional: String? { | ||
switch self { | ||
case .value(let string): | ||
return .some(string) | ||
case .null: | ||
return .none | ||
} | ||
} | ||
|
||
// Convert to integer | ||
public var integer: Int? { | ||
switch self { | ||
case .value(let string): | ||
return Int(string) | ||
case .null: | ||
return .none | ||
} | ||
} | ||
} | ||
|
||
public let parameters: [String: Parameter] | ||
public let name: String | ||
|
||
/// initialise WebSocket extension parameters from string | ||
init?(from header: some StringProtocol) { | ||
let split = header.split(separator: ";", omittingEmptySubsequences: true).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }[...] | ||
if let name = split.first { | ||
self.name = name | ||
} else { | ||
return nil | ||
} | ||
var index = split.index(after: split.startIndex) | ||
var parameters: [String: Parameter] = [:] | ||
while index != split.endIndex { | ||
let keyValue = split[index].split(separator: "=", maxSplits: 1).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) } | ||
if let key = keyValue.first { | ||
if keyValue.count > 1 { | ||
parameters[key] = .value(keyValue[1]) | ||
} else { | ||
parameters[key] = .null | ||
} | ||
} | ||
index = split.index(after: index) | ||
} | ||
self.parameters = parameters | ||
} | ||
|
||
/// Parse all `Sec-WebSocket-Extensions` header values | ||
/// - Parameters: | ||
/// - headers: headers coming from other | ||
/// - Returns: Array of extensions | ||
public static func parseHeaders(_ headers: HTTPFields) -> [WebSocketExtensionHTTPParameters] { | ||
let extHeaders = headers[values: .secWebSocketExtensions] | ||
return extHeaders.compactMap { .init(from: $0) } | ||
} | ||
} | ||
|
||
extension WebSocketExtensionHTTPParameters { | ||
/// Initialiser used by tests | ||
package init(_ name: String, parameters: [String: Parameter]) { | ||
self.name = name | ||
self.parameters = parameters | ||
} | ||
} |
Oops, something went wrong.