Skip to content

Commit

Permalink
Added initial implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitribouniol committed Jan 6, 2024
1 parent f65a90d commit 92a019d
Show file tree
Hide file tree
Showing 9 changed files with 446 additions and 2 deletions.
56 changes: 56 additions & 0 deletions Sources/HostRouter/Application+HostRouter.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import Vapor

extension Application {
private struct HostRouterKey: StorageKey {
typealias Value = HostRouter
}

private struct HostRoutesKey: StorageKey {
typealias Value = HostRoutes
}

/// The application's host router.
public var hostRouter: HostRouter { installHostRouter() }

/// Install the host router to enable host-routing.
///
/// This method is usually called automatically once a host route is added, though it can be called manually before any routes are added if you wish to configure the middleware to be before or after other middleware you may need to install.
@discardableResult
public func installHostRouter(at position: Middlewares.Position = .end) -> HostRouter {
if let router = self.storage[HostRouterKey.self] {
return router
}

let router = HostRouter()
self.storage[HostRouterKey.self] = router
self.middleware.use(router, at: position)
return router
}

var hostRoutes: HostRoutes {
get {
if let existing = self.storage[HostRoutesKey.self] {
return existing
} else {
let new = HostRoutes()
self.storage[HostRoutesKey.self] = new
return new
}
}
set {
self.storage[HostRoutesKey.self] = newValue
}
}
}

extension Application: TopLevelHostRoutesBuilder {
public func add(_ route: HostRoute) {
self.installHostRouter()
self.hostRoutes.add(route)
}

// TODO: Vapor needs support for multiple ports before this can be enabled.
internal func grouped(port: Int) -> some TopLevelHostRoutesBuilder {
HostRoutesGroup(root: self, port: port, domainPath: [])
}
}
66 changes: 66 additions & 0 deletions Sources/HostRouter/HostComponent.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import Vapor

public enum HostComponent: Sendable, Hashable, Comparable {
case constant(String)
case parameter(String)
case anything
case catchall
}

extension HostComponent: ExpressibleByStringInterpolation {
public init(stringLiteral value: String) {
if value.hasPrefix(":") {
self = .parameter(.init(value.dropFirst()))
} else if value == "*" {
self = .anything
} else if value == "**" {
self = .catchall
} else {
self = .constant(value)
}
}
}

extension HostComponent: CustomStringConvertible {
public var description: String {
switch self {
case .anything:
return "*"
case .catchall:
return "**"
case .parameter(let name):
return ":" + name
case .constant(let constant):
return constant
}
}
}

extension StringProtocol {
/// Converts a host (either a domain or an IP address) into a reversed collection of ``HostComponent``s.
public var reversedDomainComponents: [HostComponent] {
guard !self.isIPAddress else { return [.constant(String(self))] }
return self.split(separator: ".").reversed().map { .init(stringLiteral: $0.lowercased()) }
}

/// Retrieve the port and domain from the receiving Host string.
var hostComponents: (port: String?, reverseDomain: [String]) {
let baseComponents = self.split(separator: ":")
let (host, port) = baseComponents.count == 2 ? (String(baseComponents[0]), String(baseComponents[1])) : (String(self), nil)

return (port, host.split(separator: ".").reversed().map(String.init))
}

/// Check if the string is an IP address.
private var isIPAddress: Bool {
// TODO: This was unused in Vapor, and NIO may have a better helper for this.

/// We need some scratch space to let inet_pton write into.
var ipv4Addr = in_addr()
var ipv6Addr = in6_addr()

return self.withCString { ptr in
return inet_pton(AF_INET, ptr, &ipv4Addr) == 1 || inet_pton(AF_INET6, ptr, &ipv6Addr) == 1
}
}
}
14 changes: 14 additions & 0 deletions Sources/HostRouter/HostRoute.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import Vapor

public struct HostRoute {
/// The port we are routed under.
///
/// Internal until Vapor properly supports multiple ports.
internal var port: Int?

/// The domain path, in reversed order.
public var domainPath: [HostComponent]

/// The route to follow after the host has been routed.
public var route: Route
}
18 changes: 18 additions & 0 deletions Sources/HostRouter/HostRouteCollection.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/// Groups collections of routes together for adding to a router.
public protocol HostRouteCollection {
/// Registers routes to the incoming router.
///
/// - parameters:
/// - routes: `HostRoutesBuilder` to register any new routes to.
func boot(routes: some HostRoutesBuilder) throws
}

extension HostRoutesBuilder {
/// Registers all of the routes in the group to this router.
///
/// - parameters:
/// - collection: `HostRouteCollection` to register.
public func register(collection: some HostRouteCollection) throws {
try collection.boot(routes: self)
}
}
19 changes: 17 additions & 2 deletions Sources/HostRouter/HostRouter.swift
Original file line number Diff line number Diff line change
@@ -1,2 +1,17 @@
// The Swift Programming Language
// https://docs.swift.org/swift-book
import Vapor

/// A middleware that will group routes it owns under the specified host or subdomain.
///
/// Heavily inspired by https://github.com/vapor/vapor/issues/2745#issuecomment-1450795410 and Vapor's internal DefaultResponder.
public struct HostRouter: AsyncMiddleware {
public func respond(to request: Vapor.Request, chainingTo next: Vapor.AsyncResponder) async throws -> Vapor.Response {

let hostRoutes = request.application.hostRoutes

if let cachedRoute = hostRoutes.route(for: request) {
return try await cachedRoute.route.responder.respond(to: request).get()
}

return try await next.respond(to: request)
}
}
137 changes: 137 additions & 0 deletions Sources/HostRouter/HostRoutes.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import Vapor

typealias HostPrefix = [PathComponent]

struct HostRoutes {
var routes: [HostRoute] = []
var routerMap: [HostPrefix : TrieRouter<HostRoute>] = [:]
private var hostRouter = TrieRouter(HostPrefix.self, options: [])
private var catchallHostRouter = TrieRouter(HostPrefix.self, options: [])

mutating func add(_ hostRoute: HostRoute) {
routes.append(hostRoute)

let registeredPrefix = hostRoute.prefix
if routerMap[registeredPrefix] == nil {
routerMap[registeredPrefix] = TrieRouter()
hostRouter.register(registeredPrefix, at: registeredPrefix)
if registeredPrefix.contains(.catchall) {
catchallHostRouter.register(registeredPrefix, at: registeredPrefix)
}
}

let resourceRoute = hostRoute.route

/// Remove any empty path components
let path = resourceRoute.path.filter { component in
switch component {
case .constant(let string):
return !string.isEmpty
default:
return true
}
}

/// If the route isn't explicitly a HEAD route, and it's made up solely of .constant components, register a HEAD route with the same path.
if resourceRoute.method == .GET, resourceRoute.path.allSatisfy(\.isConstant) {
let headRoute = Route(
method: .HEAD,
path: resourceRoute.path,
responder: HeadResponder(),
requestType: resourceRoute.requestType,
responseType: resourceRoute.responseType)

var hostHeadRoute = hostRoute
hostHeadRoute.route = headRoute

routerMap[registeredPrefix]?.register(hostHeadRoute, at: [.constant(HTTPMethod.HEAD.string)] + path)
}

routerMap[registeredPrefix]?.register(hostRoute, at: [.constant(resourceRoute.method.string)] + path)
}

func resourceRouter(for host: String, defaultPort: String, router: TrieRouter<HostPrefix>) -> (TrieRouter<HostRoute>, Parameters)? {
let (port, domainPath) = host.hostComponents

var hostParameters = Parameters()
hostParameters.set("port", to: port ?? defaultPort)

guard
let hostPrefix = router.route(path: [port ?? defaultPort] + domainPath, parameters: &hostParameters),
let resourceRouter = routerMap[hostPrefix]
else { return nil }

return (resourceRouter, hostParameters)
}

func route(for request: Request, router: TrieRouter<HostPrefix>? = nil) -> HostRoute? {
let defaultPort = request.application.http.server.shared.configuration.tlsConfiguration != nil ? "443" : "80"

guard
let host = request.headers[.host].first,
let (resourceRouter, hostParameters) = resourceRouter(for: host, defaultPort: defaultPort, router: router ?? hostRouter)
else { return nil }

let resourceComponents = request.url.path
.split(separator: "/")
.map(String.init)

/// If it's a HEAD request and a HEAD route exists, return that route...
if request.method == .HEAD, let hostRoute = resourceRouter.route(
path: [HTTPMethod.HEAD.string] + resourceComponents,
parameters: &request.parameters
) {
request.route = hostRoute.route
request.hostRoute = hostRoute
request.hostParameters = hostParameters
return hostRoute
}

/// ...otherwise forward HEAD requests to GET route
let method = (request.method == .HEAD) ? .GET : request.method

if let hostRoute = resourceRouter.route(
path: [method.string] + resourceComponents,
parameters: &request.parameters
) {
request.route = hostRoute.route
request.hostRoute = hostRoute
request.hostParameters = hostParameters
return hostRoute
}

if router == nil {
// Not a perfect solution, but should identify routes on a catchall host. A second version might use a boundary with a single path combining domain and resource paths, but that can be explored if this is not enough
return route(for: request, router: catchallHostRouter)
}

return nil
}
}

private extension HostRoute {
var prefix: [PathComponent] {
[port.map { .constant(String($0)) } ?? .anything] + domainPath.compactMap { component in
switch component {
case .constant(let string) where string.isEmpty: nil
case .constant(let string): .constant(string.lowercased())
case .parameter(let string): .parameter(string)
case .anything: .anything
case .catchall: .catchall
}
}
}
}

private extension PathComponent {
var isConstant: Bool {
if case .constant = self { return true }
return false
}
}

private struct HeadResponder: Responder {
func respond(to request: Request) -> EventLoopFuture<Response> {
request.eventLoop.makeSucceededFuture(.init(status: .ok))
}
}
13 changes: 13 additions & 0 deletions Sources/HostRouter/HostRoutesBuilder.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import Vapor

public protocol HostRoutesBuilder: RoutesBuilder {
func add(_ route: HostRoute)
}

public protocol TopLevelHostRoutesBuilder: HostRoutesBuilder {}

extension HostRoutesBuilder {
public func add(_ route: Route) {
add(HostRoute(port: nil, domainPath: [], route: route))
}
}
61 changes: 61 additions & 0 deletions Sources/HostRouter/HostRoutesGroup.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
struct HostRoutesGroup: TopLevelHostRoutesBuilder {
/// Router to cascade to.
let root: HostRoutesBuilder

/// Port override.
let port: Int?

/// Domain path prefix.
let domainPath: [HostComponent]

/// Creates a new routing group.
init(root: HostRoutesBuilder, port: Int?, domainPath: [HostComponent]) {
self.root = root
self.port = port
self.domainPath = domainPath
}

/// Prepend the root to the added route, and pass it up the chain.
func add(_ route: HostRoute) {
var route = route
if let port {
route.port = port
}
if !domainPath.isEmpty {
route.domainPath = domainPath + route.domainPath
}
root.add(route)
}
}

extension TopLevelHostRoutesBuilder {
public func grouped(host: some StringProtocol) -> some HostRoutesBuilder {
HostRoutesGroup(root: self, port: nil, domainPath: host.reversedDomainComponents)
}
}

extension HostRoutesBuilder {
public func grouped(subDomain: some StringProtocol) -> some HostRoutesBuilder {
HostRoutesGroup(root: self, port: nil, domainPath: subDomain.reversedDomainComponents)
}

public func grouped(reverseDomain: HostComponent...) -> some HostRoutesBuilder {
HostRoutesGroup(root: self, port: nil, domainPath: reverseDomain)
}

public func grouped(reverseDomain: [HostComponent]) -> some HostRoutesBuilder {
HostRoutesGroup(root: self, port: nil, domainPath: reverseDomain)
}

public func grouped(subDomain: some StringProtocol, configure: (HostRoutesBuilder) throws -> ()) rethrows {
try configure(HostRoutesGroup(root: self, port: nil, domainPath: subDomain.reversedDomainComponents))
}

public func grouped(reverseDomain: HostComponent..., configure: (HostRoutesBuilder) throws -> ()) rethrows {
try configure(HostRoutesGroup(root: self, port: nil, domainPath: reverseDomain))
}

public func grouped(reverseDomain: [HostComponent], configure: (HostRoutesBuilder) throws -> ()) rethrows {
try configure(HostRoutesGroup(root: self, port: nil, domainPath: reverseDomain))
}
}
Loading

0 comments on commit 92a019d

Please sign in to comment.