Skip to content

Commit

Permalink
Support both openssl and rustls
Browse files Browse the repository at this point in the history
  • Loading branch information
rustworthy committed Apr 21, 2024
1 parent dd3a989 commit 5a17498
Show file tree
Hide file tree
Showing 11 changed files with 472 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tls.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ jobs:
- name: Run tests
env:
FAKTORY_URL_SECURE: tcp://localhost:17419
run: cargo test --locked --features tls --test tls
run: cargo test --locked --features openssl,rustls --test tls
82 changes: 82 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ exclude = [".github", "docker", ".gitignore", "Makefile"]

[features]
default = []
tls = ["dep:tokio-native-tls", "dep:pin-project"]
tls = ["dep:pin-project"]
openssl = ["tls", "dep:tokio-native-tls"]
rustls = ["tls", "dep:tokio-rustls"]
binaries = ["dep:clap", "tokio/macros"]
ent = []

Expand Down Expand Up @@ -44,6 +46,7 @@ tokio = { version = "1.35.1", features = [
"time",
] }
tokio-native-tls = { version = "0.3.1", optional = true }
tokio-rustls = { version = "0.25.0", optional = true }
url = "2"

[dev-dependencies]
Expand All @@ -53,7 +56,7 @@ x509-parser = "0.15.1"

# to make -Zminimal-versions work
[target.'cfg(any())'.dependencies]
openssl = { version = "0.10.60", optional = true }
openssl-crate = { package = "openssl", version = "0.10.60", optional = true }
native-tls = { version = "0.2.4", optional = true }
num-bigint = "0.4.2"
oid-registry = "0.6.1"
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ test/e2e:
.PHONY: test/e2e/tls
test/e2e/tls:
FAKTORY_URL_SECURE=tcp://${FAKTORY_HOST}:${FAKTORY_PORT_SECURE} \
cargo test --locked --features tls --test tls
cargo test --locked --features openssl,rustls --test tls

.PHONY: test/load
test/load:
Expand Down
3 changes: 1 addition & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,4 @@ pub mod ent {
mod tls;

#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub use tls::TlsStream;
pub use tls::*;
12 changes: 12 additions & 0 deletions src/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#[cfg(feature = "openssl")]
#[cfg_attr(docsrs, doc(cfg(feature = "openssl")))]
/// Namespace for OpenSSL-powered [`TlsStream`](crate::openssl::TlsStream).
///
/// The underlying crate (`native-tls`) will use _SChannel_ on Windows,
/// _SecureTransport_ on OSX, and _OpenSSL_ on other platforms.
pub mod openssl;

#[cfg(feature = "rustls")]
#[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
/// Namespace for Rustls-powered [`TlsStream`](crate::rustls::TlsStream).
pub mod rustls;
5 changes: 3 additions & 2 deletions src/tls.rs → src/tls/openssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use tokio_native_tls::{native_tls::TlsConnector, TlsConnector as AsyncTlsConnect
///
/// ```no_run
/// # tokio_test::block_on(async {
/// use faktory::{Client, TlsStream};
/// use faktory::Client;
/// use faktory::openssl::TlsStream;
/// let tls = TlsStream::connect(None).await.unwrap();
/// let cl = Client::connect_with(tls, None).await.unwrap();
/// # drop(cl);
Expand All @@ -31,7 +32,7 @@ pub struct TlsStream<S> {
connector: AsyncTlsConnector,
hostname: String,
#[pin]
stream: NativeTlsStream<S>,
pub(crate) stream: NativeTlsStream<S>,
}

impl TlsStream<TokioTcpStream> {
Expand Down
182 changes: 182 additions & 0 deletions src/tls/rustls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#[cfg(doc)]
use crate::{Client, WorkerBuilder};

use crate::{proto::utils, Error, Reconnect};
use std::fmt::Debug;
use std::io;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream as TokioTcpStream;
use tokio_rustls::client::TlsStream as RustlsStream;
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::TlsConnector;

/// A reconnectable stream encrypted with TLS.
///
/// This can be used as an argument to [`WorkerBuilder::connect_with`] and [`Client::connect_with`] to
/// connect to a TLS-secured Faktory server.
///
/// # Examples
///
/// ```no_run
/// # tokio_test::block_on(async {
/// use faktory::Client;
/// use faktory::rustls::TlsStream;
/// let tls = TlsStream::connect(None).await.unwrap();
/// let cl = Client::connect_with(tls, None).await.unwrap();
/// # drop(cl);
/// # });
/// ```
///
#[pin_project::pin_project]

Check warning on line 32 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L32

Added line #L32 was not covered by tests
pub struct TlsStream<S> {
connector: TlsConnector,
hostname: &'static str,
#[pin]
pub(crate) stream: RustlsStream<S>,
}

impl TlsStream<TokioTcpStream> {
/// Create a new TLS connection over TCP.
///
/// If `url` is not given, will use the standard Faktory environment variables. Specifically,
/// `FAKTORY_PROVIDER` is read to get the name of the environment variable to get the address
/// from (defaults to `FAKTORY_URL`), and then that environment variable is read to get the
/// server address. If the latter environment variable is not defined, the connection will be
/// made to
///
/// ```text
/// tcp://localhost:7419
/// ```
///
/// If `url` is given, but does not specify a port, it defaults to 7419.
///
/// Internally creates a `ClientConfig` with an empty root certificates store and no client
/// authentication. Use [`with_client_config`](TlsStream::with_client_config)
/// or [`with_connector`](TlsStream::with_connector) for customized
/// `ClientConfig` and `TlsConnector` accordingly.
pub async fn connect(url: Option<&str>) -> Result<Self, Error> {
let conf = ClientConfig::builder()
.with_root_certificates(RootCertStore::empty())
.with_no_client_auth();
let con = TlsConnector::from(Arc::new(conf));
TlsStream::with_connector(con, url).await
}

Check warning on line 65 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L59-L65

Added lines #L59 - L65 were not covered by tests

/// Create a new asynchronous TLS connection over TCP using a non-default TLS configuration.
///
/// See `connect` for details about the `url` parameter.
pub async fn with_client_config(conf: ClientConfig, url: Option<&str>) -> Result<Self, Error> {
let con = TlsConnector::from(Arc::new(conf));
TlsStream::with_connector(con, url).await
}

Check warning on line 73 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L70-L73

Added lines #L70 - L73 were not covered by tests

/// Create a new asynchronous TLS connection over TCP using a connector with a non-default TLS configuration.
///
/// See `connect` for details about the `url` parameter.
pub async fn with_connector(connector: TlsConnector, url: Option<&str>) -> Result<Self, Error> {
let url = match url {
Some(url) => utils::url_parse(url),
None => utils::url_parse(&utils::get_env_url()),
}?;
let hostname = utils::host_from_url(&url);
let tcp_stream = TokioTcpStream::connect(&hostname).await?;
let hostname: &'static str = url.host_str().unwrap().to_string().leak();
Ok(TlsStream::new(tcp_stream, connector, hostname).await?)
}

Check warning on line 87 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L78-L87

Added lines #L78 - L87 were not covered by tests
}

impl<S> TlsStream<S>
where
S: AsyncRead + AsyncWrite + Send + Unpin + Reconnect + Debug + 'static,
{
/// Create a new asynchronous TLS connection on an existing stream.
///
/// Internally creates a `ClientConfig` with an empty root certificates store and no client
/// authentication. Use [`new`](TlsStream::new) for a customized `TlsConnector`.
pub async fn default(stream: S, hostname: &'static str) -> io::Result<Self> {
let conf = ClientConfig::builder()
.with_root_certificates(RootCertStore::empty())
.with_no_client_auth();

Self::new(stream, TlsConnector::from(Arc::new(conf)), hostname).await
}

Check warning on line 104 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L98-L104

Added lines #L98 - L104 were not covered by tests

/// Create a new asynchronous TLS connection on an existing stream with a non-default TLS configuration.
pub async fn new(
stream: S,
connector: TlsConnector,
hostname: &'static str,
) -> io::Result<Self> {
// let hostname: &'static str = hostname.to_string().leak();
let domain = hostname.try_into().expect("a valid DNS name or IP address");
let tls_stream = connector
.connect(domain, stream)
.await
.map_err(|e| io::Error::new(io::ErrorKind::ConnectionAborted, e))?;
Ok(TlsStream {
connector,
hostname,
stream: tls_stream,
})
}

Check warning on line 123 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L107-L123

Added lines #L107 - L123 were not covered by tests
}

#[async_trait::async_trait]
impl<S> Reconnect for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Send + Unpin + Reconnect + Debug + 'static + Sync,
{
async fn reconnect(&mut self) -> io::Result<Self> {
let stream = self.stream.get_mut().0.reconnect().await?;
Self::new(stream, self.connector.clone(), &self.hostname).await
}

Check warning on line 134 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L131-L134

Added lines #L131 - L134 were not covered by tests
}

impl<S> Deref for TlsStream<S> {
type Target = RustlsStream<S>;
fn deref(&self) -> &Self::Target {
&self.stream
}

Check warning on line 141 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L139-L141

Added lines #L139 - L141 were not covered by tests
}

impl<S> DerefMut for TlsStream<S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.stream
}

Check warning on line 147 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L145-L147

Added lines #L145 - L147 were not covered by tests
}

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsStream<S> {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<io::Result<()>> {
self.project().stream.poll_read(cx, buf)
}

Check warning on line 157 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L151-L157

Added lines #L151 - L157 were not covered by tests
}

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsStream<S> {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, io::Error>> {
self.project().stream.poll_write(cx, buf)
}

Check warning on line 167 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L161-L167

Added lines #L161 - L167 were not covered by tests

fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
self.project().stream.poll_flush(cx)
}

Check warning on line 174 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L169-L174

Added lines #L169 - L174 were not covered by tests

fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
self.project().stream.poll_shutdown(cx)
}

Check warning on line 181 in src/tls/rustls.rs

View check run for this annotation

Codecov / codecov/patch

src/tls/rustls.rs#L176-L181

Added lines #L176 - L181 were not covered by tests
}
Loading

0 comments on commit 5a17498

Please sign in to comment.