Skip to content

Commit

Permalink
Upgrade rustls-pemfile and tokio-rustls (#253)
Browse files Browse the repository at this point in the history
  • Loading branch information
gustavowd authored Mar 27, 2024
1 parent 6aad370 commit c259882
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 59 deletions.
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ tokio = { version = "1.35.1", default-features = false, features = [
"time",
] }
tokio-serial = { version = "5.4.4", default-features = false }
rustls-pemfile = "1.0.3"
tokio-rustls = "0.24.1"
rustls-pemfile = "2.1.1"
tokio-rustls = "0.25.0"
pkcs8 = { version = "0.10.2", features = ["encryption", "pem", "std"] }
pem = "3.0.3"
webpki = "0.22.4"
pki-types = { package = "rustls-pki-types", version = "1" }
rustls = "0.22.3"

[features]
default = ["rtu", "tcp"]
Expand Down
59 changes: 31 additions & 28 deletions examples/tls-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,25 @@ use std::{
};

use pkcs8::der::Decode;
use pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls_pemfile::{certs, pkcs8_private_keys};
use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, PrivateKey};
use tokio_rustls::TlsConnector;
use webpki::TrustAnchor;

fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
certs(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
.map(|mut certs| certs.drain(..).map(Certificate).collect())
fn load_certs(path: &Path) -> io::Result<Vec<CertificateDer<'static>>> {
certs(&mut BufReader::new(File::open(path)?)).collect()
}

fn load_keys(path: &Path, password: Option<&str>) -> io::Result<Vec<PrivateKey>> {
fn load_keys(path: &Path, password: Option<&str>) -> io::Result<PrivateKeyDer<'static>> {
let expected_tag = match &password {
Some(_) => "ENCRYPTED PRIVATE KEY",
None => "PRIVATE KEY",
};

if expected_tag.eq("PRIVATE KEY") {
pkcs8_private_keys(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
.map(|mut keys| keys.drain(..).map(PrivateKey).collect())
.next()
.unwrap()
.map(Into::into)
} else {
let content = std::fs::read(path)?;
let mut iter = pem::parse_many(content)
Expand All @@ -56,9 +54,23 @@ fn load_keys(path: &Path, password: Option<&str>) -> io::Result<Vec<PrivateKey>>
io::Error::new(io::ErrorKind::InvalidData, err.to_string())
})?;
let key = decrypted.as_bytes().to_vec();
let key = rustls::PrivateKey(key);
let private_keys = vec![key];
io::Result::Ok(private_keys)
match rustls_pemfile::read_one_from_slice(&key)
.expect("cannot parse private key .pem file")
{
Some((rustls_pemfile::Item::Pkcs1Key(key), _keys)) => {
io::Result::Ok(key.into())
}
Some((rustls_pemfile::Item::Pkcs8Key(key), _keys)) => {
io::Result::Ok(key.into())
}
Some((rustls_pemfile::Item::Sec1Key(key), _keys)) => {
io::Result::Ok(key.into())
}
_ => io::Result::Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid key",
)),
}
}
None => io::Result::Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid key")),
},
Expand All @@ -73,37 +85,28 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let socket_addr: SocketAddr = "127.0.0.1:8802".parse()?;

let mut root_cert_store = rustls::RootCertStore::empty();
let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
let ca_path = Path::new("./pki/ca.pem");
let mut pem = BufReader::new(File::open(ca_path)?);
let certs = rustls_pemfile::certs(&mut pem)?;
let trust_anchors = certs.iter().map(|cert| {
let ta = TrustAnchor::try_from_cert_der(&cert[..]).expect("cert should parse as anchor!");
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
});
root_cert_store.add_trust_anchors(trust_anchors);
let certs = rustls_pemfile::certs(&mut pem).collect::<Result<Vec<_>, _>>()?;
root_cert_store.add_parsable_certificates(certs);

let domain = "localhost";
let cert_path = Path::new("./pki/client.pem");
let key_path = Path::new("./pki/client.key");
let certs = load_certs(cert_path)?;
let mut keys = load_keys(key_path, None)?;
let key = load_keys(key_path, None)?;

let config = rustls::ClientConfig::builder()
.with_safe_defaults()
let config = tokio_rustls::rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_client_auth_cert(certs, keys.remove(0))
.with_client_auth_cert(certs, key)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
let connector = TlsConnector::from(Arc::new(config));

let stream = TcpStream::connect(&socket_addr).await?;
stream.set_nodelay(true)?;

let domain = rustls::ServerName::try_from(domain)
let domain = ServerName::try_from(domain)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;

let transport = connector.connect(domain, stream).await?;
Expand Down
62 changes: 33 additions & 29 deletions examples/tls-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,27 @@ use std::{
};

use pkcs8::der::Decode;
use pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls_pemfile::{certs, pkcs8_private_keys};
use tokio::net::{TcpListener, TcpStream};
use tokio_modbus::{prelude::*, server::tcp::Server};
use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, PrivateKey};
use tokio_rustls::{TlsAcceptor, TlsConnector};
use webpki::TrustAnchor;

fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
certs(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
.map(|mut certs| certs.drain(..).map(Certificate).collect())
fn load_certs(path: &Path) -> io::Result<Vec<CertificateDer<'static>>> {
certs(&mut BufReader::new(File::open(path)?)).collect()
}

fn load_keys(path: &Path, password: Option<&str>) -> io::Result<Vec<PrivateKey>> {
fn load_keys(path: &Path, password: Option<&str>) -> io::Result<PrivateKeyDer<'static>> {
let expected_tag = match &password {
Some(_) => "ENCRYPTED PRIVATE KEY",
None => "PRIVATE KEY",
};

if expected_tag.eq("PRIVATE KEY") {
pkcs8_private_keys(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
.map(|mut keys| keys.drain(..).map(PrivateKey).collect())
.next()
.unwrap()
.map(Into::into)
} else {
let content = std::fs::read(path)?;
let mut iter = pem::parse_many(content)
Expand All @@ -60,9 +58,23 @@ fn load_keys(path: &Path, password: Option<&str>) -> io::Result<Vec<PrivateKey>>
io::Error::new(io::ErrorKind::InvalidData, err.to_string())
})?;
let key = decrypted.as_bytes().to_vec();
let key = rustls::PrivateKey(key);
let private_keys = vec![key];
io::Result::Ok(private_keys)
match rustls_pemfile::read_one_from_slice(&key)
.expect("cannot parse private key .pem file")
{
Some((rustls_pemfile::Item::Pkcs1Key(key), _keys)) => {
io::Result::Ok(key.into())
}
Some((rustls_pemfile::Item::Pkcs8Key(key), _keys)) => {
io::Result::Ok(key.into())
}
Some((rustls_pemfile::Item::Sec1Key(key), _keys)) => {
io::Result::Ok(key.into())
}
_ => io::Result::Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid key",
)),
}
}
None => io::Result::Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid key")),
},
Expand Down Expand Up @@ -189,11 +201,10 @@ async fn server_context(socket_addr: SocketAddr) -> anyhow::Result<()> {
let cert_path = Path::new("./pki/server.pem");
let key_path = Path::new("./pki/server.key");
let certs = load_certs(cert_path)?;
let mut keys = load_keys(key_path, None)?;
let key = load_keys(key_path, None)?;
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, keys.remove(0))
.with_single_cert(certs, key)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
let acceptor = TlsAcceptor::from(Arc::new(config));

Expand Down Expand Up @@ -222,35 +233,28 @@ async fn client_context(socket_addr: SocketAddr) {
let mut root_cert_store = rustls::RootCertStore::empty();
let ca_path = Path::new("./pki/ca.pem");
let mut pem = BufReader::new(File::open(ca_path).unwrap());
let certs = rustls_pemfile::certs(&mut pem).unwrap();
let trust_anchors = certs.iter().map(|cert| {
let ta = TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
});
root_cert_store.add_trust_anchors(trust_anchors);
let certs = rustls_pemfile::certs(&mut pem)
.collect::<Result<Vec<_>, _>>()
.unwrap();
root_cert_store.add_parsable_certificates(certs);

let domain = "localhost";
let cert_path = Path::new("./pki/client.pem");
let key_path = Path::new("./pki/client.key");
let certs = load_certs(cert_path).unwrap();
let mut keys = load_keys(key_path, None).unwrap();
let key = load_keys(key_path, None).unwrap();

let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_client_auth_cert(certs, keys.remove(0))
.with_client_auth_cert(certs, key)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))
.unwrap();
let connector = TlsConnector::from(Arc::new(config));

let stream = TcpStream::connect(&socket_addr).await.unwrap();
stream.set_nodelay(true).unwrap();

let domain = rustls::ServerName::try_from(domain)
let domain = ServerName::try_from(domain)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))
.unwrap();

Expand Down

0 comments on commit c259882

Please sign in to comment.