Skip to content

Commit

Permalink
Add a DNS NameResolver (#5)
Browse files Browse the repository at this point in the history
Motivation:

Many users will rely on DNS to resolve the IP addresses of servers to
connect to, we should therefore provide a DNS name resolver.

Modifications:

- Add a DNS name resolver factory capable of resolving IP addresses
- Add the resolver to the registry defaults

Result:

Can resolve DNS targets
  • Loading branch information
glbrntt authored Sep 25, 2024
1 parent 2978160 commit ba557f2
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 2 deletions.
127 changes: 127 additions & 0 deletions Sources/GRPCNIOTransportCore/Client/Resolver/NameResolver+DNS.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright 2024, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

private import GRPCCore

extension ResolvableTargets {
/// A resolvable target for addresses which can be resolved via DNS.
///
/// If you already have an IPv4 or IPv6 address use ``ResolvableTargets/IPv4`` and
/// ``ResolvableTargets/IPv6`` respectively.
public struct DNS: ResolvableTarget, Sendable {
/// The host to resolve via DNS.
public var host: String

/// The port to use with resolved addresses.
public var port: Int

/// Create a new DNS target.
/// - Parameters:
/// - host: The host to resolve via DNS.
/// - port: The port to use with resolved addresses.
public init(host: String, port: Int) {
self.host = host
self.port = port
}
}
}

extension ResolvableTarget where Self == ResolvableTargets.DNS {
/// Creates a new resolvable DNS target.
/// - Parameters:
/// - host: The host address to resolve.
/// - port: The port to use for each resolved address.
/// - Returns: A ``ResolvableTarget``.
public static func dns(host: String, port: Int = 443) -> Self {
return Self(host: host, port: port)
}
}

@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
extension NameResolvers {
/// A ``NameResolverFactory`` for ``ResolvableTargets/DNS`` targets.
public struct DNS: NameResolverFactory {
public typealias Target = ResolvableTargets.DNS

/// Create a new DNS name resolver factory.
public init() {}

public func resolver(for target: Target) -> NameResolver {
let resolver = Self.Resolver(target: target)
return NameResolver(names: RPCAsyncSequence(wrapping: resolver), updateMode: .pull)
}
}
}

@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
extension NameResolvers.DNS {
struct Resolver: Sendable {
var target: ResolvableTargets.DNS

init(target: ResolvableTargets.DNS) {
self.target = target
}

func resolve(
isolation actor: isolated (any Actor)? = nil
) async throws -> NameResolutionResult {
let addresses: [SocketAddress]

do {
addresses = try await DNSResolver.resolve(host: self.target.host, port: self.target.port)
} catch let error as CancellationError {
throw error
} catch {
throw RPCError(
code: .internalError,
message: "Couldn't resolve address for \(self.target.host):\(self.target.port)",
cause: error
)
}

return NameResolutionResult(endpoints: [Endpoint(addresses: addresses)], serviceConfig: nil)
}
}
}

@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
extension NameResolvers.DNS.Resolver: AsyncSequence {
typealias Element = NameResolutionResult

func makeAsyncIterator() -> AsyncIterator {
return AsyncIterator(resolver: self)
}

struct AsyncIterator: AsyncIteratorProtocol {
typealias Element = NameResolutionResult

private let resolver: NameResolvers.DNS.Resolver

init(resolver: NameResolvers.DNS.Resolver) {
self.resolver = resolver
}

func next() async throws -> NameResolutionResult? {
return try await self.next(isolation: nil)
}

func next(
isolation actor: isolated (any Actor)?
) async throws(any Error) -> NameResolutionResult? {
return try await self.resolver.resolve(isolation: actor)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,17 @@
@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *)
public struct NameResolverRegistry {
private enum Factory {
case dns(NameResolvers.DNS)
case ipv4(NameResolvers.IPv4)
case ipv6(NameResolvers.IPv6)
case unix(NameResolvers.UnixDomainSocket)
case vsock(NameResolvers.VirtualSocket)
case other(any NameResolverFactory)

init(_ factory: some NameResolverFactory) {
if let ipv4 = factory as? NameResolvers.IPv4 {
if let dns = factory as? NameResolvers.DNS {
self = .dns(dns)
} else if let ipv4 = factory as? NameResolvers.IPv4 {
self = .ipv4(ipv4)
} else if let ipv6 = factory as? NameResolvers.IPv6 {
self = .ipv6(ipv6)
Expand All @@ -62,6 +65,8 @@ public struct NameResolverRegistry {

func makeResolverIfCompatible<Target: ResolvableTarget>(_ target: Target) -> NameResolver? {
switch self {
case .dns(let factory):
return factory.makeResolverIfCompatible(target)
case .ipv4(let factory):
return factory.makeResolverIfCompatible(target)
case .ipv6(let factory):
Expand All @@ -77,6 +82,8 @@ public struct NameResolverRegistry {

func hasTarget<Target: ResolvableTarget>(_ target: Target) -> Bool {
switch self {
case .dns(let factory):
return factory.isCompatible(withTarget: target)
case .ipv4(let factory):
return factory.isCompatible(withTarget: target)
case .ipv6(let factory):
Expand All @@ -92,6 +99,8 @@ public struct NameResolverRegistry {

func `is`<Factory: NameResolverFactory>(ofType factoryType: Factory.Type) -> Bool {
switch self {
case .dns:
return NameResolvers.DNS.self == factoryType
case .ipv4:
return NameResolvers.IPv4.self == factoryType
case .ipv6:
Expand All @@ -116,12 +125,14 @@ public struct NameResolverRegistry {
/// Returns a new name resolver registry with the default factories registered.
///
/// The default resolvers include:
/// - ``NameResolvers/DNS``,
/// - ``NameResolvers/IPv4``,
/// - ``NameResolvers/IPv6``,
/// - ``NameResolvers/UnixDomainSocket``,
/// - ``NameResolvers/VirtualSocket``.
public static var defaults: Self {
var resolvers = NameResolverRegistry()
resolvers.registerFactory(NameResolvers.DNS())
resolvers.registerFactory(NameResolvers.IPv4())
resolvers.registerFactory(NameResolvers.IPv6())
resolvers.registerFactory(NameResolvers.UnixDomainSocket())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,12 @@ final class NameResolverRegistryTests: XCTestCase {

func testDefaultResolvers() {
let resolvers = NameResolverRegistry.defaults
XCTAssert(resolvers.containsFactory(ofType: NameResolvers.DNS.self))
XCTAssert(resolvers.containsFactory(ofType: NameResolvers.IPv4.self))
XCTAssert(resolvers.containsFactory(ofType: NameResolvers.IPv6.self))
XCTAssert(resolvers.containsFactory(ofType: NameResolvers.UnixDomainSocket.self))
XCTAssert(resolvers.containsFactory(ofType: NameResolvers.VirtualSocket.self))
XCTAssertEqual(resolvers.count, 4)
XCTAssertEqual(resolvers.count, 5)
}

func testMakeResolver() {
Expand Down Expand Up @@ -167,6 +168,28 @@ final class NameResolverRegistryTests: XCTestCase {
}
}

func testDNSResolverForIPv4() async throws {
let factory = NameResolvers.DNS()
let resolver = factory.resolver(for: .dns(host: "127.0.0.1", port: 1234))
XCTAssertEqual(resolver.updateMode, .pull)

var iterator = resolver.names.makeAsyncIterator()
let result = try await XCTUnwrapAsync { try await iterator.next() }
XCTAssertEqual(result.endpoints, [Endpoint(.ipv4(host: "127.0.0.1", port: 1234))])
XCTAssertNil(result.serviceConfig)
}

func testDNSResolverForIPv6() async throws {
let factory = NameResolvers.DNS()
let resolver = factory.resolver(for: .dns(host: "::1", port: 1234))
XCTAssertEqual(resolver.updateMode, .pull)

var iterator = resolver.names.makeAsyncIterator()
let result = try await XCTUnwrapAsync { try await iterator.next() }
XCTAssertEqual(result.endpoints, [Endpoint(.ipv6(host: "::1", port: 1234))])
XCTAssertNil(result.serviceConfig)
}

func testIPv4ResolverForSingleHost() async throws {
let factory = NameResolvers.IPv4()
let resolver = factory.resolver(for: .ipv4(host: "foo", port: 1234))
Expand Down

0 comments on commit ba557f2

Please sign in to comment.