Skip to content

Commit

Permalink
work continues
Browse files Browse the repository at this point in the history
  • Loading branch information
blind-oracle committed Apr 16, 2024
1 parent 5b2cebf commit da17b1e
Show file tree
Hide file tree
Showing 14 changed files with 584 additions and 151 deletions.
13 changes: 10 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@ arc-swap = "1"
async-scoped = { version = "0.8", features = ["use-tokio"] }
async-trait = "0.1"
axum = "0.7"
axum-server = { version = "0.6", features = ["tls-rustls"] }
candid = "0.10"
clap = { version = "4.5", features = ["derive", "string"] }
clap_derive = "4.5"
ctrlc = { version = "3.4", features = [ "termination" ] }
fqdn = "0.3"
futures = "0.3"
hickory-resolver = {version = "0.24", features = ["dns-over-rustls", "dns-over-https-rustls", "dnssec-ring"] }
futures-util = "0.3"
hickory-resolver = { version = "0.24", features = ["dns-over-https-rustls", "webpki-roots", "dnssec-ring"] }
http = "1.1"
humantime = "2.1"
hyper = "1.2"
hyper-util = "0.1"
jemallocator = "0.5"
jemalloc-ctl = "0.5"
lazy_static = "1.4"
Expand Down Expand Up @@ -48,9 +52,10 @@ webpki-roots = "0.26"
serde = "1.0"
serde_json = "1.0"
strum = "0.26"
strum_macros = "0.26"
thiserror = "1.0"
tokio = "1.36"
tokio-util = "0.7"
tokio = { version = "1.36", features = [ "full" ] }
tokio-util = {version = "0.7", features = [ "full" ] }
tokio-rustls = "0.26"
tower = "0.4"
tower_governor = "0.3"
Expand All @@ -59,6 +64,7 @@ tower-http = { version = "0.5", features = [
"request-id",
"compression-full",
] }
tower-service = "0.3"
tracing = "0.1"
tracing-core = "0.1"
tracing-serde = "0.1"
Expand All @@ -68,3 +74,4 @@ x509-parser = "0.16"

[dev-dependencies]
criterion = { version = "0.5", features = ["async_tokio"] }
tempfile = "3.10"
44 changes: 34 additions & 10 deletions src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
use std::{net::IpAddr, time::Duration};
use std::{
net::{IpAddr, SocketAddr},
time::Duration,
};

use anyhow::{anyhow, Error};
use clap::{Args, Parser};
use hickory_resolver::config::{Protocol, CLOUDFLARE_IPS};
use hickory_resolver::config::CLOUDFLARE_IPS;
use humantime::parse_duration;

use crate::{
core::{AUTHOR_NAME, SERVICE_NAME},
http::dns,
};

// Clap does not support prefixes due to macro limitations
// so we have to add prefixes manually (long = "...")

#[derive(Parser)]
#[clap(name = SERVICE_NAME)]
#[clap(author = AUTHOR_NAME)]
pub struct Cli {
#[command(flatten, next_help_heading = "HTTP Client")]
pub http_client: HttpClient,

#[command(flatten, next_help_heading = "DNS")]
#[command(flatten, next_help_heading = "DNS Resolver")]
pub dns: Dns,

#[command(flatten, next_help_heading = "HTTP Server")]
pub http_server: HttpServer,
}

// Clap does not support prefixes due to macro limitations
// so we have to add them manually (long = "...")

#[derive(Args)]
pub struct HttpClient {
/// Timeout for HTTP connection phase
Expand All @@ -47,17 +52,36 @@ pub struct HttpClient {
pub http2_keepalive_timeout: Duration,
}

#[derive(Args, Clone, Debug)]
#[derive(Args)]
pub struct Dns {
/// List of DNS servers to use
#[clap(long = "dns-servers", default_values_t = CLOUDFLARE_IPS)]
pub servers: Vec<IpAddr>,

/// DNS protocol to use (udp/tcp/tls/https)
/// DNS protocol to use (clear/tls/https)
#[clap(long = "dns-protocol", default_value = "tls")]
pub protocol: dns::Protocol,

/// TLS name for DNS-over-TLS and DNS-over-HTTPS protocols
/// TLS name to expect for TLS and HTTPS protocols (e.g. "dns.google" or "cloudflare-dns.com")
#[clap(long = "dns-tls-name", default_value = "cloudflare-dns.com")]
pub tls_name: String,

/// Cache size for the resolver (in number of DNS records)
#[clap(long = "dns-cache-size", default_value = "2048")]
pub cache_size: usize,
}

#[derive(Args)]
pub struct HttpServer {
/// Where to listen for HTTP
#[clap(long = "listen-http", default_value = "[::1]:8080")]
pub http: SocketAddr,

/// Where to listen for HTTPS
#[clap(long = "listen-https", default_value = "[::1]:8443")]
pub https: SocketAddr,

/// Backlog of incoming connections to set on the listening socket.
#[clap(long, default_value = "8192")]
pub backlog: u32,
}
39 changes: 37 additions & 2 deletions src/core.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,45 @@
use anyhow::Error;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tracing::{error, warn};

use crate::cli::Cli;
use crate::{cli::Cli, http::server::Server};

pub const SERVICE_NAME: &str = "ic_gateway";
pub const AUTHOR_NAME: &str = "Boundary Node Team <[email protected]>";

pub async fn main(_cli: Cli) -> Result<(), Error> {
pub async fn main(cli: Cli) -> Result<(), Error> {
let token = CancellationToken::new();
let tracker = TaskTracker::new();

// Handle SIGTERM/SIGHUP and Ctrl+C
// Cancelling a token cancels all of its clones too, except the ones from .child_token()
let handler_token = token.clone();
ctrlc::set_handler(move || handler_token.cancel())?;

let router = axum::Router::new().route("/", axum::routing::get(|| async { "Hello, World!" }));

let http_server = Server::new(
cli.http_server.http,
cli.http_server.backlog,
router,
token.child_token(),
None,
);

tracker.spawn(async move {
match http_server.start().await {
Ok(()) => {}
Err(e) => {
error!("Unable to start server: {e}");
}
};
});

warn!("Service is running, waiting for the shutdown signal");
token.cancelled().await;
warn!("Shutdown signal received, cleaning up");
tracker.close();
tracker.wait().await;

Ok(())
}
45 changes: 45 additions & 0 deletions src/http/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use std::sync::Arc;

use async_trait::async_trait;
use mockall::automock;

use crate::{cli, core::SERVICE_NAME, http::dns::Resolver, tls::prepare_rustls_client_config};

#[automock]
#[async_trait]
pub trait Client: Send + Sync {
async fn execute(&self, req: reqwest::Request) -> Result<reqwest::Response, reqwest::Error>;
}

pub struct ReqwestClient(reqwest::Client);

impl ReqwestClient {
pub fn new(cli: &cli::Cli) -> Result<Self, anyhow::Error> {
let http = &cli.http_client;

let client = reqwest::Client::builder()
.use_preconfigured_tls(prepare_rustls_client_config())
.dns_resolver(Arc::new(Resolver::new(&cli.dns)))
.connect_timeout(http.timeout_connect)
.timeout(http.timeout)
.tcp_nodelay(true)
.tcp_keepalive(Some(http.tcp_keepalive))
.http2_keep_alive_interval(Some(http.http2_keepalive))
.http2_keep_alive_timeout(http.http2_keepalive_timeout)
.http2_keep_alive_while_idle(true)
.http2_adaptive_window(true)
.user_agent(SERVICE_NAME)
.redirect(reqwest::redirect::Policy::none())
.no_proxy()
.build()?;

Ok(Self(client))
}
}

#[async_trait]
impl Client for ReqwestClient {
async fn execute(&self, req: reqwest::Request) -> Result<reqwest::Response, reqwest::Error> {
self.0.execute(req).await
}
}
89 changes: 32 additions & 57 deletions src/http/dns.rs
Original file line number Diff line number Diff line change
@@ -1,57 +1,53 @@
use std::{net::SocketAddr, str::FromStr, sync::Arc};
use std::{net::SocketAddr, sync::Arc};

use anyhow::{anyhow, Error};
use hickory_resolver::{
config::{NameServerConfigGroup, ResolverConfig, ResolverOpts},
lookup_ip::LookupIpIntoIter,
TokioAsyncResolver,
};
use once_cell::sync::OnceCell;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use strum_macros::EnumString;

use crate::cli::Dns;

#[derive(Clone, Debug)]
#[derive(Clone, Debug, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum Protocol {
Clear,
Tls,
Https,
}

impl FromStr for Protocol {
type Err = Error;
#[derive(Debug, Clone)]
pub struct Resolver(Arc<TokioAsyncResolver>);

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"clear" => Protocol::Clear,
"tls" => Protocol::Tls,
"https" => Protocol::Https,
_ => return Err(anyhow!("Unknown DNS protocol")),
})
// new() must be called in Tokio context
impl Resolver {
pub fn new(cli: &Dns) -> Self {
let name_servers = match cli.protocol {
Protocol::Clear => NameServerConfigGroup::from_ips_clear(&cli.servers, 53, true),
Protocol::Tls => {
NameServerConfigGroup::from_ips_tls(&cli.servers, 853, cli.tls_name.clone(), true)
}
Protocol::Https => {
NameServerConfigGroup::from_ips_https(&cli.servers, 443, cli.tls_name.clone(), true)
}
};

let cfg = ResolverConfig::from_parts(None, vec![], name_servers);

let mut opts = ResolverOpts::default();
opts.rotate = true;
opts.cache_size = cli.cache_size;
opts.use_hosts_file = false;
opts.preserve_intermediates = false;
opts.try_tcp_on_error = true;

let resolver = TokioAsyncResolver::tokio(cfg, opts);
Self(Arc::new(resolver))
}
}

pub fn prepare_dns_resolver(cli: Dns) -> TokioAsyncResolver {
let name_servers = match cli.protocol {
Protocol::Clear => NameServerConfigGroup::from_ips_clear(&cli.servers, 53, true),
Protocol::Tls => NameServerConfigGroup::from_ips_tls(&cli.servers, 853, cli.tls_name, true),
Protocol::Https => {
NameServerConfigGroup::from_ips_https(&cli.servers, 443, cli.tls_name, true)
}
};

let cfg = ResolverConfig::from_parts(None, vec![], name_servers);

let mut opts = ResolverOpts::default();
opts.rotate = true;
opts.cache_size = 2048;
opts.use_hosts_file = false;
opts.preserve_intermediates = false;
opts.try_tcp_on_error = true;

TokioAsyncResolver::tokio(cfg, opts)
}

struct SocketAddrs {
iter: LookupIpIntoIter,
}
Expand All @@ -64,34 +60,13 @@ impl Iterator for SocketAddrs {
}
}

#[derive(Debug, Clone)]
pub struct DnsResolver {
// Constructor is called most probably not in the Tokio context
// so we delay creation of the resolver using once_cell
state: Arc<OnceCell<TokioAsyncResolver>>,
cli: Dns,
}

impl DnsResolver {
pub fn new(cli: &Dns) -> Self {
Self {
state: Arc::new(OnceCell::new()),
cli: cli.clone(),
}
}
}

// Implement resolving for Reqwest using Hickory
impl Resolve for DnsResolver {
impl Resolve for Resolver {
fn resolve(&self, name: Name) -> Resolving {
let resolver = self.clone();

Box::pin(async move {
let resolver = resolver
.state
.get_or_init(|| prepare_dns_resolver(resolver.cli));

let lookup = resolver.lookup_ip(name.as_str()).await?;
let lookup = resolver.0.lookup_ip(name.as_str()).await?;
let addrs: Addrs = Box::new(SocketAddrs {
iter: lookup.into_iter(),
});
Expand Down
49 changes: 2 additions & 47 deletions src/http/mod.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,3 @@
pub mod client;
pub mod dns;

use std::sync::Arc;

use async_trait::async_trait;
use mockall::automock;
use reqwest::{Client, Error, Request, Response};

use crate::{cli, core::SERVICE_NAME, http::dns::DnsResolver, tls::prepare_rustls_client_config};

#[automock]
#[async_trait]
pub trait HttpClient: Send + Sync {
async fn execute(&self, req: Request) -> Result<Response, Error>;
}

pub struct ReqwestClient(Client);

impl ReqwestClient {
pub fn new(cli: &cli::Cli) -> Result<Self, anyhow::Error> {
let http = &cli.http_client;

let client = Client::builder()
.use_preconfigured_tls(prepare_rustls_client_config())
.dns_resolver(Arc::new(DnsResolver::new(&cli.dns)))
.connect_timeout(http.timeout_connect)
.timeout(http.timeout)
.tcp_nodelay(true)
.tcp_keepalive(Some(http.tcp_keepalive))
.http2_keep_alive_interval(Some(http.http2_keepalive))
.http2_keep_alive_timeout(http.http2_keepalive_timeout)
.http2_keep_alive_while_idle(true)
.http2_adaptive_window(true)
.user_agent(SERVICE_NAME)
.redirect(reqwest::redirect::Policy::none())
.no_proxy()
.build()?;

Ok(Self(client))
}
}

#[async_trait]
impl HttpClient for ReqwestClient {
async fn execute(&self, req: Request) -> Result<Response, Error> {
self.0.execute(req).await
}
}
pub mod server;
Loading

0 comments on commit da17b1e

Please sign in to comment.