Skip to content

Commit

Permalink
feat(connector): add SSL support for external cdc postgres connector (#…
Browse files Browse the repository at this point in the history
…15690)

Co-authored-by: Tao Wu <[email protected]>
Co-authored-by: StrikeW <[email protected]>
  • Loading branch information
3 people authored Mar 28, 2024
1 parent 26a8ebe commit dd1249d
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 8 deletions.
23 changes: 19 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/connector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ mysql_common = { version = "0.32", default-features = false, features = [
nexmark = { version = "0.2", features = ["serde"] }
num-bigint = "0.4"
opendal = "0.45"
openssl = "0.10"
parking_lot = "0.12"
paste = "1"
postgres-openssl = "0.5.0"
prometheus = { version = "0.13", features = ["process"] }
prost = { version = "0.12", features = ["no-recursion-limit"] }
prost-reflect = "0.13"
Expand Down
2 changes: 2 additions & 0 deletions src/connector/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def_anyhow_newtype! {
tokio_rustls::rustls::Error => "TLS error",
rumqttc::v5::ClientError => "MQTT error",
rumqttc::v5::OptionError => "MQTT error",

openssl::error::ErrorStack => "OpenSSL error",
}

pub type ConnectorResult<T, E = ConnectorError> = std::result::Result<T, E>;
Expand Down
182 changes: 182 additions & 0 deletions src/connector/src/source/cdc/external/maybe_tls_connector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use core::task;
use std::error::Error;
use std::io;
use std::pin::Pin;
use std::task::Poll;

use futures::{Future, FutureExt};
use openssl::error::ErrorStack;
use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect};
use tokio_postgres::NoTls;

pub enum MaybeMakeTlsConnector {
NoTls(NoTls),
Tls(MakeTlsConnector),
}

impl<S> MakeTlsConnect<S> for MaybeMakeTlsConnector
where
S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send,
{
type Error = ErrorStack;
type Stream = MaybeTlsStream<S>;
type TlsConnect = MaybeTlsConnector;

fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
match self {
MaybeMakeTlsConnector::NoTls(make_connector) => {
let connector =
<NoTls as MakeTlsConnect<S>>::make_tls_connect(make_connector, domain)
.expect("make NoTls connector always success");
Ok(MaybeTlsConnector::NoTls(connector))
}
MaybeMakeTlsConnector::Tls(make_connector) => {
<MakeTlsConnector as MakeTlsConnect<S>>::make_tls_connect(make_connector, domain)
.map(MaybeTlsConnector::Tls)
}
}
}
}

pub enum MaybeTlsConnector {
NoTls(NoTls),
Tls(TlsConnector),
}

impl<S> TlsConnect<S> for MaybeTlsConnector
where
S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
{
type Error = Box<dyn Error + Send + Sync>;
type Future = MaybeTlsFuture<Self::Stream, Self::Error>;
type Stream = MaybeTlsStream<S>;

fn connect(self, stream: S) -> Self::Future {
match self {
MaybeTlsConnector::NoTls(connector) => MaybeTlsFuture::NoTls(connector.connect(stream)),
MaybeTlsConnector::Tls(connector) => MaybeTlsFuture::Tls(Box::pin(
connector
.connect(stream)
.map(|x| x.map(|x| MaybeTlsStream::Tls(x))),
)),
}
}
}

pub enum MaybeTlsStream<S> {
NoTls(NoTlsStream),
Tls(TlsStream<S>),
}

impl<S> AsyncRead for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncRead>::poll_read(Pin::new(stream), cx, buf)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncRead>::poll_read(Pin::new(stream), cx, buf)
}
}
}
}

impl<S> AsyncWrite for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncWrite>::poll_write(Pin::new(stream), cx, buf)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncWrite>::poll_write(Pin::new(stream), cx, buf)
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncWrite>::poll_flush(Pin::new(stream), cx)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncWrite>::poll_flush(Pin::new(stream), cx)
}
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncWrite>::poll_shutdown(Pin::new(stream), cx)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncWrite>::poll_shutdown(Pin::new(stream), cx)
}
}
}
}

impl<S> tls::TlsStream for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn channel_binding(&self) -> tls::ChannelBinding {
match self {
MaybeTlsStream::NoTls(stream) => stream.channel_binding(),
MaybeTlsStream::Tls(stream) => stream.channel_binding(),
}
}
}

pub enum MaybeTlsFuture<S, E> {
NoTls(NoTlsFuture),
Tls(Pin<Box<dyn Future<Output = Result<S, E>> + Send>>),
}

impl<S, E> Future for MaybeTlsFuture<MaybeTlsStream<S>, E>
where
MaybeTlsStream<S>: Sync + Send,
E: std::convert::From<tokio_postgres::tls::NoTlsError>,
{
type Output = Result<MaybeTlsStream<S>, E>;

fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
match &mut *self {
MaybeTlsFuture::NoTls(fut) => fut
.poll_unpin(cx)
.map(|x| x.map(|x| MaybeTlsStream::NoTls(x)))
.map_err(|x| x.into()),
MaybeTlsFuture::Tls(fut) => fut.poll_unpin(cx),
}
}
}
33 changes: 33 additions & 0 deletions src/connector/src/source/cdc/external/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
pub mod mock_external_table;
mod postgres;

#[cfg(not(madsim))]
mod maybe_tls_connector;

use std::collections::HashMap;
use std::fmt;

use anyhow::Context;
use futures::stream::BoxStream;
Expand Down Expand Up @@ -245,6 +249,35 @@ pub struct ExternalTableConfig {
pub schema: String,
#[serde(rename = "table.name")]
pub table: String,
/// `ssl.mode` specifies the SSL/TLS encryption level for secure communication with Postgres.
/// Choices include `disable`, `prefer`, and `require`.
/// This field is optional. `prefer` is used if not specified.
#[serde(rename = "ssl.mode", default = "Default::default")]
pub sslmode: SslMode,
}

#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SslMode {
Disable,
Prefer,
Require,
}

impl Default for SslMode {
fn default() -> Self {
Self::Prefer
}
}

impl fmt::Display for SslMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
SslMode::Disable => "disable",
SslMode::Prefer => "prefer",
SslMode::Require => "require",
})
}
}

impl ExternalTableReader for MySqlExternalTableReader {
Expand Down
35 changes: 31 additions & 4 deletions src/connector/src/source/cdc/external/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use futures::stream::BoxStream;
use futures::{pin_mut, StreamExt};
use futures_async_stream::try_stream;
use itertools::Itertools;
use openssl::ssl::{SslConnector, SslMethod};
use postgres_openssl::MakeTlsConnector;
use risingwave_common::catalog::Schema;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::DatumRef;
Expand All @@ -30,9 +32,11 @@ use tokio_postgres::NoTls;

use crate::error::{ConnectorError, ConnectorResult};
use crate::parser::postgres_row_to_owned_row;
#[cfg(not(madsim))]
use crate::source::cdc::external::maybe_tls_connector::MaybeMakeTlsConnector;
use crate::source::cdc::external::{
CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader,
SchemaTableName,
SchemaTableName, SslMode,
};

#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
Expand Down Expand Up @@ -129,11 +133,34 @@ impl PostgresExternalTableReader {
.context("failed to extract postgres connector properties")?;

let database_url = format!(
"postgresql://{}:{}@{}:{}/{}",
config.username, config.password, config.host, config.port, config.database
"postgresql://{}:{}@{}:{}/{}?sslmode={}",
config.username,
config.password,
config.host,
config.port,
config.database,
config.sslmode
);

let (client, connection) = tokio_postgres::connect(&database_url, NoTls).await?;
#[cfg(not(madsim))]
let connector = match config.sslmode {
SslMode::Disable => MaybeMakeTlsConnector::NoTls(NoTls),
SslMode::Prefer => match SslConnector::builder(SslMethod::tls()) {
Ok(builder) => MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build())),
Err(e) => {
tracing::warn!(error = %e.as_report(), "SSL connector error");
MaybeMakeTlsConnector::NoTls(NoTls)
}
},
SslMode::Require => {
let builder = SslConnector::builder(SslMethod::tls())?;
MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
}
};
#[cfg(madsim)]
let connector = NoTls;

let (client, connection) = tokio_postgres::connect(&database_url, connector).await?;

tokio::spawn(async move {
if let Err(e) = connection.await {
Expand Down

0 comments on commit dd1249d

Please sign in to comment.