-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f65a90d
commit 92a019d
Showing
9 changed files
with
446 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import Vapor | ||
|
||
extension Application { | ||
private struct HostRouterKey: StorageKey { | ||
typealias Value = HostRouter | ||
} | ||
|
||
private struct HostRoutesKey: StorageKey { | ||
typealias Value = HostRoutes | ||
} | ||
|
||
/// The application's host router. | ||
public var hostRouter: HostRouter { installHostRouter() } | ||
|
||
/// Install the host router to enable host-routing. | ||
/// | ||
/// This method is usually called automatically once a host route is added, though it can be called manually before any routes are added if you wish to configure the middleware to be before or after other middleware you may need to install. | ||
@discardableResult | ||
public func installHostRouter(at position: Middlewares.Position = .end) -> HostRouter { | ||
if let router = self.storage[HostRouterKey.self] { | ||
return router | ||
} | ||
|
||
let router = HostRouter() | ||
self.storage[HostRouterKey.self] = router | ||
self.middleware.use(router, at: position) | ||
return router | ||
} | ||
|
||
var hostRoutes: HostRoutes { | ||
get { | ||
if let existing = self.storage[HostRoutesKey.self] { | ||
return existing | ||
} else { | ||
let new = HostRoutes() | ||
self.storage[HostRoutesKey.self] = new | ||
return new | ||
} | ||
} | ||
set { | ||
self.storage[HostRoutesKey.self] = newValue | ||
} | ||
} | ||
} | ||
|
||
extension Application: TopLevelHostRoutesBuilder { | ||
public func add(_ route: HostRoute) { | ||
self.installHostRouter() | ||
self.hostRoutes.add(route) | ||
} | ||
|
||
// TODO: Vapor needs support for multiple ports before this can be enabled. | ||
internal func grouped(port: Int) -> some TopLevelHostRoutesBuilder { | ||
HostRoutesGroup(root: self, port: port, domainPath: []) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import Vapor | ||
|
||
public enum HostComponent: Sendable, Hashable, Comparable { | ||
case constant(String) | ||
case parameter(String) | ||
case anything | ||
case catchall | ||
} | ||
|
||
extension HostComponent: ExpressibleByStringInterpolation { | ||
public init(stringLiteral value: String) { | ||
if value.hasPrefix(":") { | ||
self = .parameter(.init(value.dropFirst())) | ||
} else if value == "*" { | ||
self = .anything | ||
} else if value == "**" { | ||
self = .catchall | ||
} else { | ||
self = .constant(value) | ||
} | ||
} | ||
} | ||
|
||
extension HostComponent: CustomStringConvertible { | ||
public var description: String { | ||
switch self { | ||
case .anything: | ||
return "*" | ||
case .catchall: | ||
return "**" | ||
case .parameter(let name): | ||
return ":" + name | ||
case .constant(let constant): | ||
return constant | ||
} | ||
} | ||
} | ||
|
||
extension StringProtocol { | ||
/// Converts a host (either a domain or an IP address) into a reversed collection of ``HostComponent``s. | ||
public var reversedDomainComponents: [HostComponent] { | ||
guard !self.isIPAddress else { return [.constant(String(self))] } | ||
return self.split(separator: ".").reversed().map { .init(stringLiteral: $0.lowercased()) } | ||
} | ||
|
||
/// Retrieve the port and domain from the receiving Host string. | ||
var hostComponents: (port: String?, reverseDomain: [String]) { | ||
let baseComponents = self.split(separator: ":") | ||
let (host, port) = baseComponents.count == 2 ? (String(baseComponents[0]), String(baseComponents[1])) : (String(self), nil) | ||
|
||
return (port, host.split(separator: ".").reversed().map(String.init)) | ||
} | ||
|
||
/// Check if the string is an IP address. | ||
private var isIPAddress: Bool { | ||
// TODO: This was unused in Vapor, and NIO may have a better helper for this. | ||
|
||
/// We need some scratch space to let inet_pton write into. | ||
var ipv4Addr = in_addr() | ||
var ipv6Addr = in6_addr() | ||
|
||
return self.withCString { ptr in | ||
return inet_pton(AF_INET, ptr, &ipv4Addr) == 1 || inet_pton(AF_INET6, ptr, &ipv6Addr) == 1 | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import Vapor | ||
|
||
public struct HostRoute { | ||
/// The port we are routed under. | ||
/// | ||
/// Internal until Vapor properly supports multiple ports. | ||
internal var port: Int? | ||
|
||
/// The domain path, in reversed order. | ||
public var domainPath: [HostComponent] | ||
|
||
/// The route to follow after the host has been routed. | ||
public var route: Route | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
/// Groups collections of routes together for adding to a router. | ||
public protocol HostRouteCollection { | ||
/// Registers routes to the incoming router. | ||
/// | ||
/// - parameters: | ||
/// - routes: `HostRoutesBuilder` to register any new routes to. | ||
func boot(routes: some HostRoutesBuilder) throws | ||
} | ||
|
||
extension HostRoutesBuilder { | ||
/// Registers all of the routes in the group to this router. | ||
/// | ||
/// - parameters: | ||
/// - collection: `HostRouteCollection` to register. | ||
public func register(collection: some HostRouteCollection) throws { | ||
try collection.boot(routes: self) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,17 @@ | ||
// The Swift Programming Language | ||
// https://docs.swift.org/swift-book | ||
import Vapor | ||
|
||
/// A middleware that will group routes it owns under the specified host or subdomain. | ||
/// | ||
/// Heavily inspired by https://github.com/vapor/vapor/issues/2745#issuecomment-1450795410 and Vapor's internal DefaultResponder. | ||
public struct HostRouter: AsyncMiddleware { | ||
public func respond(to request: Vapor.Request, chainingTo next: Vapor.AsyncResponder) async throws -> Vapor.Response { | ||
|
||
let hostRoutes = request.application.hostRoutes | ||
|
||
if let cachedRoute = hostRoutes.route(for: request) { | ||
return try await cachedRoute.route.responder.respond(to: request).get() | ||
} | ||
|
||
return try await next.respond(to: request) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import Vapor | ||
|
||
typealias HostPrefix = [PathComponent] | ||
|
||
struct HostRoutes { | ||
var routes: [HostRoute] = [] | ||
var routerMap: [HostPrefix : TrieRouter<HostRoute>] = [:] | ||
private var hostRouter = TrieRouter(HostPrefix.self, options: []) | ||
private var catchallHostRouter = TrieRouter(HostPrefix.self, options: []) | ||
|
||
mutating func add(_ hostRoute: HostRoute) { | ||
routes.append(hostRoute) | ||
|
||
let registeredPrefix = hostRoute.prefix | ||
if routerMap[registeredPrefix] == nil { | ||
routerMap[registeredPrefix] = TrieRouter() | ||
hostRouter.register(registeredPrefix, at: registeredPrefix) | ||
if registeredPrefix.contains(.catchall) { | ||
catchallHostRouter.register(registeredPrefix, at: registeredPrefix) | ||
} | ||
} | ||
|
||
let resourceRoute = hostRoute.route | ||
|
||
/// Remove any empty path components | ||
let path = resourceRoute.path.filter { component in | ||
switch component { | ||
case .constant(let string): | ||
return !string.isEmpty | ||
default: | ||
return true | ||
} | ||
} | ||
|
||
/// If the route isn't explicitly a HEAD route, and it's made up solely of .constant components, register a HEAD route with the same path. | ||
if resourceRoute.method == .GET, resourceRoute.path.allSatisfy(\.isConstant) { | ||
let headRoute = Route( | ||
method: .HEAD, | ||
path: resourceRoute.path, | ||
responder: HeadResponder(), | ||
requestType: resourceRoute.requestType, | ||
responseType: resourceRoute.responseType) | ||
|
||
var hostHeadRoute = hostRoute | ||
hostHeadRoute.route = headRoute | ||
|
||
routerMap[registeredPrefix]?.register(hostHeadRoute, at: [.constant(HTTPMethod.HEAD.string)] + path) | ||
} | ||
|
||
routerMap[registeredPrefix]?.register(hostRoute, at: [.constant(resourceRoute.method.string)] + path) | ||
} | ||
|
||
func resourceRouter(for host: String, defaultPort: String, router: TrieRouter<HostPrefix>) -> (TrieRouter<HostRoute>, Parameters)? { | ||
let (port, domainPath) = host.hostComponents | ||
|
||
var hostParameters = Parameters() | ||
hostParameters.set("port", to: port ?? defaultPort) | ||
|
||
guard | ||
let hostPrefix = router.route(path: [port ?? defaultPort] + domainPath, parameters: &hostParameters), | ||
let resourceRouter = routerMap[hostPrefix] | ||
else { return nil } | ||
|
||
return (resourceRouter, hostParameters) | ||
} | ||
|
||
func route(for request: Request, router: TrieRouter<HostPrefix>? = nil) -> HostRoute? { | ||
let defaultPort = request.application.http.server.shared.configuration.tlsConfiguration != nil ? "443" : "80" | ||
|
||
guard | ||
let host = request.headers[.host].first, | ||
let (resourceRouter, hostParameters) = resourceRouter(for: host, defaultPort: defaultPort, router: router ?? hostRouter) | ||
else { return nil } | ||
|
||
let resourceComponents = request.url.path | ||
.split(separator: "/") | ||
.map(String.init) | ||
|
||
/// If it's a HEAD request and a HEAD route exists, return that route... | ||
if request.method == .HEAD, let hostRoute = resourceRouter.route( | ||
path: [HTTPMethod.HEAD.string] + resourceComponents, | ||
parameters: &request.parameters | ||
) { | ||
request.route = hostRoute.route | ||
request.hostRoute = hostRoute | ||
request.hostParameters = hostParameters | ||
return hostRoute | ||
} | ||
|
||
/// ...otherwise forward HEAD requests to GET route | ||
let method = (request.method == .HEAD) ? .GET : request.method | ||
|
||
if let hostRoute = resourceRouter.route( | ||
path: [method.string] + resourceComponents, | ||
parameters: &request.parameters | ||
) { | ||
request.route = hostRoute.route | ||
request.hostRoute = hostRoute | ||
request.hostParameters = hostParameters | ||
return hostRoute | ||
} | ||
|
||
if router == nil { | ||
// Not a perfect solution, but should identify routes on a catchall host. A second version might use a boundary with a single path combining domain and resource paths, but that can be explored if this is not enough | ||
return route(for: request, router: catchallHostRouter) | ||
} | ||
|
||
return nil | ||
} | ||
} | ||
|
||
private extension HostRoute { | ||
var prefix: [PathComponent] { | ||
[port.map { .constant(String($0)) } ?? .anything] + domainPath.compactMap { component in | ||
switch component { | ||
case .constant(let string) where string.isEmpty: nil | ||
case .constant(let string): .constant(string.lowercased()) | ||
case .parameter(let string): .parameter(string) | ||
case .anything: .anything | ||
case .catchall: .catchall | ||
} | ||
} | ||
} | ||
} | ||
|
||
private extension PathComponent { | ||
var isConstant: Bool { | ||
if case .constant = self { return true } | ||
return false | ||
} | ||
} | ||
|
||
private struct HeadResponder: Responder { | ||
func respond(to request: Request) -> EventLoopFuture<Response> { | ||
request.eventLoop.makeSucceededFuture(.init(status: .ok)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import Vapor | ||
|
||
public protocol HostRoutesBuilder: RoutesBuilder { | ||
func add(_ route: HostRoute) | ||
} | ||
|
||
public protocol TopLevelHostRoutesBuilder: HostRoutesBuilder {} | ||
|
||
extension HostRoutesBuilder { | ||
public func add(_ route: Route) { | ||
add(HostRoute(port: nil, domainPath: [], route: route)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
struct HostRoutesGroup: TopLevelHostRoutesBuilder { | ||
/// Router to cascade to. | ||
let root: HostRoutesBuilder | ||
|
||
/// Port override. | ||
let port: Int? | ||
|
||
/// Domain path prefix. | ||
let domainPath: [HostComponent] | ||
|
||
/// Creates a new routing group. | ||
init(root: HostRoutesBuilder, port: Int?, domainPath: [HostComponent]) { | ||
self.root = root | ||
self.port = port | ||
self.domainPath = domainPath | ||
} | ||
|
||
/// Prepend the root to the added route, and pass it up the chain. | ||
func add(_ route: HostRoute) { | ||
var route = route | ||
if let port { | ||
route.port = port | ||
} | ||
if !domainPath.isEmpty { | ||
route.domainPath = domainPath + route.domainPath | ||
} | ||
root.add(route) | ||
} | ||
} | ||
|
||
extension TopLevelHostRoutesBuilder { | ||
public func grouped(host: some StringProtocol) -> some HostRoutesBuilder { | ||
HostRoutesGroup(root: self, port: nil, domainPath: host.reversedDomainComponents) | ||
} | ||
} | ||
|
||
extension HostRoutesBuilder { | ||
public func grouped(subDomain: some StringProtocol) -> some HostRoutesBuilder { | ||
HostRoutesGroup(root: self, port: nil, domainPath: subDomain.reversedDomainComponents) | ||
} | ||
|
||
public func grouped(reverseDomain: HostComponent...) -> some HostRoutesBuilder { | ||
HostRoutesGroup(root: self, port: nil, domainPath: reverseDomain) | ||
} | ||
|
||
public func grouped(reverseDomain: [HostComponent]) -> some HostRoutesBuilder { | ||
HostRoutesGroup(root: self, port: nil, domainPath: reverseDomain) | ||
} | ||
|
||
public func grouped(subDomain: some StringProtocol, configure: (HostRoutesBuilder) throws -> ()) rethrows { | ||
try configure(HostRoutesGroup(root: self, port: nil, domainPath: subDomain.reversedDomainComponents)) | ||
} | ||
|
||
public func grouped(reverseDomain: HostComponent..., configure: (HostRoutesBuilder) throws -> ()) rethrows { | ||
try configure(HostRoutesGroup(root: self, port: nil, domainPath: reverseDomain)) | ||
} | ||
|
||
public func grouped(reverseDomain: [HostComponent], configure: (HostRoutesBuilder) throws -> ()) rethrows { | ||
try configure(HostRoutesGroup(root: self, port: nil, domainPath: reverseDomain)) | ||
} | ||
} |
Oops, something went wrong.