Skip to content

Allow DNS override #675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ let package = Package(
resources: [
.copy("Resources/self_signed_cert.pem"),
.copy("Resources/self_signed_key.pem"),
.copy("Resources/example.com.cert.pem"),
.copy("Resources/example.com.private-key.pem"),
]
),
]
Expand Down
2 changes: 2 additions & 0 deletions Package@swift-5.5.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ let package = Package(
resources: [
.copy("Resources/self_signed_cert.pem"),
.copy("Resources/self_signed_key.pem"),
.copy("Resources/example.com.cert.pem"),
.copy("Resources/example.com.private-key.pem"),
]
),
]
Expand Down
4 changes: 2 additions & 2 deletions Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ extension HTTPClient {

// this loop is there to follow potential redirects
while true {
let preparedRequest = try HTTPClientRequest.Prepared(currentRequest)
let preparedRequest = try HTTPClientRequest.Prepared(currentRequest, dnsOverride: configuration.dnsOverride)
let response = try await executeCancellable(preparedRequest, deadline: deadline, logger: logger)

guard var redirectState = currentRedirectState else {
Expand Down Expand Up @@ -131,7 +131,7 @@ extension HTTPClient {
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<HTTPClientResponse, Swift.Error>) -> Void in
let transaction = Transaction(
request: request,
requestOptions: .init(idleReadTimeout: nil),
requestOptions: .fromClientConfiguration(self.configuration),
logger: logger,
connectionDeadline: .now() + (self.configuration.timeout.connectionCreationTimeout),
preferredEventLoop: eventLoop,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ extension HTTPClientRequest {

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientRequest.Prepared {
init(_ request: HTTPClientRequest) throws {
init(_ request: HTTPClientRequest, dnsOverride: [String: String] = [:]) throws {
guard let url = URL(string: request.url) else {
throw HTTPClientError.invalidURL
}
Expand All @@ -58,7 +58,7 @@ extension HTTPClientRequest.Prepared {

self.init(
url: url,
poolKey: .init(url: deconstructedURL, tlsConfiguration: nil),
poolKey: .init(url: deconstructedURL, tlsConfiguration: nil, dnsOverride: dnsOverride),
requestFramingMetadata: metadata,
head: .init(
version: .http1_1,
Expand Down
54 changes: 47 additions & 7 deletions Sources/AsyncHTTPClient/ConnectionPool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,25 @@

import NIOSSL

#if canImport(Darwin)
import Darwin.C
#elseif os(Linux) || os(FreeBSD) || os(Android)
import Glibc
#else
#error("unsupported target operating system")
#endif

extension String {
var isIPAddress: Bool {
var ipv4Address = in_addr()
var ipv6Address = in6_addr()
return self.withCString { host in
inet_pton(AF_INET, host, &ipv4Address) == 1 ||
inet_pton(AF_INET6, host, &ipv6Address) == 1
}
}
}

enum ConnectionPool {
/// Used by the `ConnectionPool` to index its `HTTP1ConnectionProvider`s
///
Expand All @@ -24,15 +43,18 @@ enum ConnectionPool {
var scheme: Scheme
var connectionTarget: ConnectionTarget
private var tlsConfiguration: BestEffortHashableTLSConfiguration?
var serverNameIndicatorOverride: String?

init(
scheme: Scheme,
connectionTarget: ConnectionTarget,
tlsConfiguration: BestEffortHashableTLSConfiguration? = nil
tlsConfiguration: BestEffortHashableTLSConfiguration? = nil,
serverNameIndicatorOverride: String?
) {
self.scheme = scheme
self.connectionTarget = connectionTarget
self.tlsConfiguration = tlsConfiguration
self.serverNameIndicatorOverride = serverNameIndicatorOverride
}

var description: String {
Expand All @@ -48,26 +70,44 @@ enum ConnectionPool {
case .unixSocket(let socketPath):
hostDescription = socketPath
}
return "\(self.scheme)://\(hostDescription) TLS-hash: \(hash)"
return "\(self.scheme)://\(hostDescription)\(self.serverNameIndicatorOverride.map { " SNI: \($0)" } ?? "") TLS-hash: \(hash) "
}
}
}

extension DeconstructedURL {
func applyDNSOverride(_ dnsOverride: [String: String]) -> (ConnectionTarget, serverNameIndicatorOverride: String?) {
guard
let originalHost = self.connectionTarget.host,
let hostOverride = dnsOverride[originalHost]
else {
return (self.connectionTarget, nil)
}
return (
.init(remoteHost: hostOverride, port: self.connectionTarget.port ?? self.scheme.defaultPort),
serverNameIndicatorOverride: originalHost.isIPAddress ? nil : originalHost
)
}
}

extension ConnectionPool.Key {
init(url: DeconstructedURL, tlsConfiguration: TLSConfiguration?) {
init(url: DeconstructedURL, tlsConfiguration: TLSConfiguration?, dnsOverride: [String: String]) {
let (connectionTarget, serverNameIndicatorOverride) = url.applyDNSOverride(dnsOverride)
self.init(
scheme: url.scheme,
connectionTarget: url.connectionTarget,
connectionTarget: connectionTarget,
tlsConfiguration: tlsConfiguration.map {
BestEffortHashableTLSConfiguration(wrapping: $0)
}
},
serverNameIndicatorOverride: serverNameIndicatorOverride
)
}

init(_ request: HTTPClient.Request) {
init(_ request: HTTPClient.Request, dnsOverride: [String: String] = [:]) {
self.init(
url: request.deconstructedURL,
tlsConfiguration: request.tlsConfiguration
tlsConfiguration: request.tlsConfiguration,
dnsOverride: dnsOverride
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ extension HTTPConnectionPool.ConnectionFactory {
}
let tlsEventHandler = TLSEventsHandler(deadline: deadline)

let sslServerHostname = self.key.connectionTarget.sslServerHostname
let sslServerHostname = self.key.serverNameIndicator
let sslContextFuture = self.sslContextCache.sslContext(
tlsConfiguration: tlsConfig,
eventLoop: channel.eventLoop,
Expand Down Expand Up @@ -409,7 +409,7 @@ extension HTTPConnectionPool.ConnectionFactory {
#if canImport(Network)
if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) {
// create NIOClientTCPBootstrap with NIOTS TLS provider
let bootstrapFuture = tlsConfig.getNWProtocolTLSOptions(on: eventLoop).map {
let bootstrapFuture = tlsConfig.getNWProtocolTLSOptions(on: eventLoop, serverNameIndicatorOverride: key.serverNameIndicatorOverride).map {
options -> NIOClientTCPBootstrapProtocol in

tsBootstrap
Expand All @@ -434,7 +434,6 @@ extension HTTPConnectionPool.ConnectionFactory {
}
#endif

let sslServerHostname = self.key.connectionTarget.sslServerHostname
let sslContextFuture = sslContextCache.sslContext(
tlsConfiguration: tlsConfig,
eventLoop: eventLoop,
Expand All @@ -449,7 +448,7 @@ extension HTTPConnectionPool.ConnectionFactory {
let sync = channel.pipeline.syncOperations
let sslHandler = try NIOSSLClientHandler(
context: sslContext,
serverHostname: sslServerHostname
serverHostname: self.key.serverNameIndicator
)
let tlsEventHandler = TLSEventsHandler(deadline: deadline)

Expand Down Expand Up @@ -488,6 +487,12 @@ extension Scheme {
}
}

extension ConnectionPool.Key {
var serverNameIndicator: String? {
serverNameIndicatorOverride ?? connectionTarget.sslServerHostname
}
}

extension ConnectionTarget {
fileprivate var sslServerHostname: String? {
switch self {
Expand Down
8 changes: 6 additions & 2 deletions Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@ struct RequestOptions {
/// The maximal `TimeAmount` that is allowed to pass between `channelRead`s from the Channel.
var idleReadTimeout: TimeAmount?

init(idleReadTimeout: TimeAmount?) {
var dnsOverride: [String: String]

init(idleReadTimeout: TimeAmount?, dnsOverride: [String: String]) {
self.idleReadTimeout = idleReadTimeout
self.dnsOverride = dnsOverride
}
}

extension RequestOptions {
static func fromClientConfiguration(_ configuration: HTTPClient.Configuration) -> Self {
RequestOptions(
idleReadTimeout: configuration.timeout.read
idleReadTimeout: configuration.timeout.read,
dnsOverride: configuration.dnsOverride
)
}
}
11 changes: 11 additions & 0 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,17 @@ public class HTTPClient {
public struct Configuration {
/// TLS configuration, defaults to `TLSConfiguration.makeClientConfiguration()`.
public var tlsConfiguration: Optional<TLSConfiguration>

/// Sometimes it can be useful to connect to one host e.g. `x.example.com` but
/// request and validate the certificate chain as if we would connect to `y.example.com`.
/// ``dnsOverride`` allows to do just that by mapping host names which we will request and validate the certificate chain, to a different
/// host name which will be used to actually connect to.
///
/// **Example:** if ``dnsOverride`` is set to `["example.com": "localhost"]` and we execute a request with a
/// `url` of `https://example.com/`, the ``HTTPClient`` will actually open a connection to `localhost` instead of `example.com`.
/// ``HTTPClient`` will still request certificates from the server for `example.com` and validate them as if we would connect to `example.com`.
public var dnsOverride: [String: String] = [:]

/// Enables following 3xx redirects automatically.
///
/// Following redirects are supported:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ extension TLSConfiguration {
///
/// - Parameter eventLoop: EventLoop to wait for creation of options on
/// - Returns: Future holding NWProtocolTLS Options
func getNWProtocolTLSOptions(on eventLoop: EventLoop) -> EventLoopFuture<NWProtocolTLS.Options> {
func getNWProtocolTLSOptions(on eventLoop: EventLoop, serverNameIndicatorOverride: String?) -> EventLoopFuture<NWProtocolTLS.Options> {
let promise = eventLoop.makePromise(of: NWProtocolTLS.Options.self)
Self.tlsDispatchQueue.async {
do {
let options = try self.getNWProtocolTLSOptions()
let options = try self.getNWProtocolTLSOptions(serverNameIndicatorOverride: serverNameIndicatorOverride)
promise.succeed(options)
} catch {
promise.fail(error)
Expand All @@ -82,7 +82,7 @@ extension TLSConfiguration {
/// create NWProtocolTLS.Options for use with NIOTransportServices from the NIOSSL TLSConfiguration
///
/// - Returns: Equivalent NWProtocolTLS Options
func getNWProtocolTLSOptions() throws -> NWProtocolTLS.Options {
func getNWProtocolTLSOptions(serverNameIndicatorOverride: String?) throws -> NWProtocolTLS.Options {
let options = NWProtocolTLS.Options()

let useMTELGExplainer = """
Expand All @@ -92,6 +92,12 @@ extension TLSConfiguration {
platform networking stack).
"""

if let serverNameIndicatorOverride = serverNameIndicatorOverride {
serverNameIndicatorOverride.withCString { serverNameIndicatorOverride in
sec_protocol_options_set_tls_server_name(options.securityProtocolOptions, serverNameIndicatorOverride)
}
}

// minimum TLS protocol
if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) {
sec_protocol_options_set_min_tls_protocol_version(options.securityProtocolOptions, self.minimumTLSVersion.nwTLSProtocolVersion)
Expand Down
7 changes: 3 additions & 4 deletions Sources/AsyncHTTPClient/RequestBag.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
50
}

let poolKey: ConnectionPool.Key

let task: HTTPClient.Task<Delegate.Response>
var eventLoop: EventLoop {
self.task.eventLoop
Expand Down Expand Up @@ -63,6 +65,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
connectionDeadline: NIODeadline,
requestOptions: RequestOptions,
delegate: Delegate) throws {
self.poolKey = .init(request, dnsOverride: requestOptions.dnsOverride)
self.eventLoopPreference = eventLoopPreference
self.task = task
self.state = .init(redirectHandler: redirectHandler)
Expand Down Expand Up @@ -392,10 +395,6 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

extension RequestBag: HTTPSchedulableRequest {
var poolKey: ConnectionPool.Key {
ConnectionPool.Key(self.request)
}

var tlsConfiguration: TLSConfiguration? {
self.request.tlsConfiguration
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ extension AsyncAwaitEndToEndTests {
("testImmediateDeadline", testImmediateDeadline),
("testConnectTimeout", testConnectTimeout),
("testSelfSignedCertificateIsRejectedWithCorrectErrorIfRequestDeadlineIsExceeded", testSelfSignedCertificateIsRejectedWithCorrectErrorIfRequestDeadlineIsExceeded),
("testDnsOverride", testDnsOverride),
("testInvalidURL", testInvalidURL),
("testRedirectChangesHostHeader", testRedirectChangesHostHeader),
("testShutdown", testShutdown),
Expand Down
58 changes: 58 additions & 0 deletions Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,64 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
}
}

func testDnsOverride() {
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
XCTAsyncTest(timeout: 5) {
/// key + cert was created with the following code (depends on swift-certificates)
/// ```
/// let privateKey = P384.Signing.PrivateKey()
/// let name = try DistinguishedName {
/// OrganizationName("Self Signed")
/// CommonName("localhost")
/// }
/// let certificate = try Certificate(
/// version: .v3,
/// serialNumber: .init(),
/// publicKey: .init(privateKey.publicKey),
/// notValidBefore: Date(),
/// notValidAfter: Date() + .days(365),
/// issuer: name,
/// subject: name,
/// signatureAlgorithm: .ecdsaWithSHA384,
/// extensions: try .init {
/// SubjectAlternativeNames([.dnsName("example.com")])
/// ExtendedKeyUsage([.serverAuth])
/// },
/// issuerPrivateKey: .init(privateKey)
/// )
/// ```
let certPath = Bundle.module.path(forResource: "example.com.cert", ofType: "pem")!
let keyPath = Bundle.module.path(forResource: "example.com.private-key", ofType: "pem")!
let localhostCert = try NIOSSLCertificate.fromPEMFile(certPath)
let configuration = TLSConfiguration.makeServerConfiguration(
certificateChain: localhostCert.map { .certificate($0) },
privateKey: .file(keyPath)
)
let bin = HTTPBin(.http2(tlsConfiguration: configuration))
defer { XCTAssertNoThrow(try bin.shutdown()) }

var config = HTTPClient.Configuration()
.enableFastFailureModeForTesting()
var tlsConfig = TLSConfiguration.makeClientConfiguration()

tlsConfig.trustRoots = .certificates(localhostCert)
config.tlsConfiguration = tlsConfig
// this is the actual configuration under test
config.dnsOverride = ["example.com": "localhost"]

let localClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: config)
defer { XCTAssertNoThrow(try localClient.syncShutdown()) }
let request = HTTPClientRequest(url: "https://example.com:\(bin.port)/echohostheader")
let response = await XCTAssertNoThrowWithResult(try await localClient.execute(request, deadline: .now() + .seconds(2)))
XCTAssertEqual(response?.status, .ok)
XCTAssertEqual(response?.version, .http2)
var body = try await response?.body.collect(upTo: 1024)
let readableBytes = body?.readableBytes ?? 0
let responseInfo = try body?.readJSONDecodable(RequestInfo.self, length: readableBytes)
XCTAssertEqual(responseInfo?.data, "example.com\(bin.port == 443 ? "" : ":\(bin.port)")")
}
}

func testInvalidURL() {
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
XCTAsyncTest(timeout: 5) {
Expand Down
2 changes: 1 addition & 1 deletion Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class HTTPClientNIOTSTests: XCTestCase {
var tlsConfig = TLSConfiguration.makeClientConfiguration()
tlsConfig.trustRoots = .file("not/a/certificate")

XCTAssertThrowsError(try tlsConfig.getNWProtocolTLSOptions()) { error in
XCTAssertThrowsError(try tlsConfig.getNWProtocolTLSOptions(serverNameIndicatorOverride: nil)) { error in
switch error {
case let error as NIOSSL.NIOSSLError where error == .failedToLoadCertificate:
break
Expand Down
Loading