Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(connector): add SSL support for external cdc postgres connector #15690

Merged
merged 10 commits into from
Mar 28, 2024
23 changes: 19 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/connector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ mysql_common = { version = "0.31", default-features = false, features = [
nexmark = { version = "0.2", features = ["serde"] }
num-bigint = "0.4"
opendal = "0.44"
openssl = "0.10"
parking_lot = "0.12"
paste = "1"
postgres-openssl = "0.5.0"
prometheus = { version = "0.13", features = ["process"] }
prost = { version = "0.12", features = ["no-recursion-limit"] }
prost-reflect = "0.13"
Expand Down
2 changes: 2 additions & 0 deletions src/connector/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def_anyhow_newtype! {
tokio_rustls::rustls::Error => "TLS error",
rumqttc::v5::ClientError => "MQTT error",
rumqttc::v5::OptionError => "MQTT error",

openssl::error::ErrorStack => "OpenSSL error",
}

pub type ConnectorResult<T, E = ConnectorError> = std::result::Result<T, E>;
Expand Down
30 changes: 30 additions & 0 deletions src/connector/src/source/cdc/external/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub mod mock_external_table;
mod postgres;

use std::collections::HashMap;
use std::fmt;

use anyhow::Context;
use futures::stream::BoxStream;
Expand Down Expand Up @@ -245,6 +246,35 @@ pub struct ExternalTableConfig {
pub schema: String,
#[serde(rename = "table.name")]
pub table: String,
/// `ssl.mode` specifies the SSL/TLS encryption level for secure communication with Postgres.
/// Choices include `disable`, `prefer`, and `require`.
/// This field is optional. `prefer` is used if not specified.
#[serde(rename = "ssl.mode", default = "Default::default")]
pub sslmode: SslMode,
neverchanje marked this conversation as resolved.
Show resolved Hide resolved
}

/// SSL/TLS mode for the connection to the upstream Postgres, configured through `ssl.mode`.
///
/// Variants are deserialized from their lowercase names (`disable`, `prefer`, `require`).
/// The per-variant notes describe how the Postgres external table reader picks its TLS
/// connector; the chosen mode is also passed to the server URL as `sslmode=<mode>`.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SslMode {
/// Connect with the `NoTls` connector.
Disable,
/// Use an OpenSSL connector when one can be built; otherwise log a warning and use `NoTls`.
Prefer,
/// Use an OpenSSL connector; failing to build one is returned as an error.
Require,
}

// `SslMode::default()` is what serde uses when `ssl.mode` is absent from the config; it
// currently yields `Prefer`.
// NOTE(review): before this option existed the reader connected with `NoTls` — confirm that
// `Prefer` is the intended default for already-deployed sources.
impl Default for SslMode {
fn default() -> Self {
Self::Prefer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should set the default value to Disable which is same as before for backward compatible. Then for upstreams that require TLS, the user should set ssl.mode='require'.

}
}

impl fmt::Display for SslMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
SslMode::Disable => "disable",
SslMode::Prefer => "prefer",
SslMode::Require => "require",
})
}
}

impl ExternalTableReader for MySqlExternalTableReader {
Expand Down
197 changes: 192 additions & 5 deletions src/connector/src/source/cdc/external/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,196 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use core::task;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::error::Error;
use std::io;
use std::pin::Pin;
use std::task::Poll;

use anyhow::Context;
use futures::stream::BoxStream;
use futures::{pin_mut, StreamExt};
use futures::{pin_mut, Future, FutureExt, StreamExt};
use futures_async_stream::try_stream;
use itertools::Itertools;
use openssl::error::ErrorStack;
use openssl::ssl::{SslConnector, SslMethod};
use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::DatumRef;
use serde_derive::{Deserialize, Serialize};
use thiserror_ext::AsReport;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect};
use tokio_postgres::types::PgLsn;
use tokio_postgres::NoTls;

use crate::error::{ConnectorError, ConnectorResult};
use crate::parser::postgres_row_to_owned_row;
use crate::source::cdc::external::{
CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader,
SchemaTableName,
SchemaTableName, SslMode,
};

/// Connector factory handed to `tokio_postgres::connect`: either the plaintext `NoTls`
/// marker or an OpenSSL-backed `MakeTlsConnector`, chosen from the configured `SslMode`.
enum MaybeMakeTlsConnector {
/// Connect without TLS.
NoTls(NoTls),
/// Connect through the wrapped OpenSSL connector factory.
Tls(MakeTlsConnector),
}

/// Builds a per-connection [`MaybeTlsConnector`] by delegating to whichever factory this
/// value wraps.
impl<S> MakeTlsConnect<S> for MaybeMakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send,
{
    type Error = ErrorStack;
    type Stream = MaybeTlsStream<S>;
    type TlsConnect = MaybeTlsConnector;

    /// Creates the connector for `domain`.
    ///
    /// Only the OpenSSL branch can return an error; the `NoTls` result is unwrapped with
    /// `expect`.
    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
        let connector = match self {
            Self::NoTls(plain_factory) => {
                let plain = <NoTls as MakeTlsConnect<S>>::make_tls_connect(plain_factory, domain)
                    .expect("make NoTls connector always success");
                MaybeTlsConnector::NoTls(plain)
            }
            Self::Tls(tls_factory) => MaybeTlsConnector::Tls(
                <MakeTlsConnector as MakeTlsConnect<S>>::make_tls_connect(tls_factory, domain)?,
            ),
        };
        Ok(connector)
    }
}

/// Per-connection connector produced by `MaybeMakeTlsConnector::make_tls_connect`.
enum MaybeTlsConnector {
/// Hands the stream to the `NoTls` connector.
NoTls(NoTls),
/// Hands the stream to the OpenSSL `TlsConnector`.
Tls(TlsConnector),
}

/// Starts the (optional) TLS handshake on a raw stream and exposes the outcome as a
/// [`MaybeTlsStream`].
impl<S> TlsConnect<S> for MaybeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Error = Box<dyn Error + Send + Sync>;
    type Future = MaybeTlsFuture<MaybeTlsStream<S>, Self::Error>;
    type Stream = MaybeTlsStream<S>;

    /// Consumes the connector and `stream`, returning the future that resolves to the
    /// wrapped stream.
    fn connect(self, stream: S) -> Self::Future {
        match self {
            Self::NoTls(plain) => MaybeTlsFuture::NoTls(plain.connect(stream)),
            Self::Tls(secure) => {
                // Re-tag the successful OpenSSL stream as `MaybeTlsStream::Tls`; errors pass
                // through unchanged.
                let handshake = secure
                    .connect(stream)
                    .map(|outcome| outcome.map(MaybeTlsStream::Tls));
                MaybeTlsFuture::Tls(Box::pin(handshake))
            }
        }
    }
}

/// Stream type yielded by `MaybeTlsConnector`: the `NoTls` stream or an OpenSSL TLS stream
/// layered over the raw stream `S`.
enum MaybeTlsStream<S> {
/// Stream produced by the `NoTls` connector.
NoTls(NoTlsStream),
/// OpenSSL stream wrapping `S`.
Tls(TlsStream<S>),
}

impl<S> AsyncRead for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncRead>::poll_read(Pin::new(stream), cx, buf)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncRead>::poll_read(Pin::new(stream), cx, buf)
}
}
}
}

impl<S> AsyncWrite for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncWrite>::poll_write(Pin::new(stream), cx, buf)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncWrite>::poll_write(Pin::new(stream), cx, buf)
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncWrite>::poll_flush(Pin::new(stream), cx)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncWrite>::poll_flush(Pin::new(stream), cx)
}
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncWrite>::poll_shutdown(Pin::new(stream), cx)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncWrite>::poll_shutdown(Pin::new(stream), cx)
}
}
}
}

/// Channel-binding information is taken from whichever stream is wrapped.
impl<S> tls::TlsStream for MaybeTlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    fn channel_binding(&self) -> tls::ChannelBinding {
        match self {
            Self::NoTls(plain) => tls::TlsStream::channel_binding(plain),
            Self::Tls(secure) => tls::TlsStream::channel_binding(secure),
        }
    }
}

/// Future returned by `MaybeTlsConnector::connect`, resolving to `Result<S, E>`.
enum MaybeTlsFuture<S, E> {
/// Future returned by the `NoTls` connector.
NoTls(NoTlsFuture),
/// Boxed future for the TLS path.
Tls(Pin<Box<dyn Future<Output = Result<S, E>> + Send>>),
}

/// Drives the wrapped future and normalises its output to
/// `Result<MaybeTlsStream<S>, E>`.
impl<S, E> Future for MaybeTlsFuture<MaybeTlsStream<S>, E>
where
    E: std::convert::From<tokio_postgres::tls::NoTlsError>,
{
    type Output = Result<MaybeTlsStream<S>, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        match self.get_mut() {
            // Wrap the `NoTls` stream and convert its error into the caller's error type.
            Self::NoTls(plain) => plain
                .poll_unpin(cx)
                .map(|outcome| outcome.map(MaybeTlsStream::NoTls).map_err(E::from)),
            // The TLS future already yields the final output type.
            Self::Tls(handshake) => handshake.poll_unpin(cx),
        }
    }
}

// =======

#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct PostgresOffset {
pub txid: i64,
Expand Down Expand Up @@ -129,11 +296,31 @@ impl PostgresExternalTableReader {
.context("failed to extract postgres connector properties")?;

let database_url = format!(
"postgresql://{}:{}@{}:{}/{}",
config.username, config.password, config.host, config.port, config.database
"postgresql://{}:{}@{}:{}/{}?sslmode={}",
config.username,
config.password,
config.host,
config.port,
config.database,
config.sslmode
);

let (client, connection) = tokio_postgres::connect(&database_url, NoTls).await?;
let connector = match config.sslmode {
SslMode::Disable => MaybeMakeTlsConnector::NoTls(NoTls),
SslMode::Prefer => match SslConnector::builder(SslMethod::tls()) {
Ok(builder) => MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build())),
Err(e) => {
tracing::warn!(error = %e.as_report(), "SSL connector error");
MaybeMakeTlsConnector::NoTls(NoTls)
}
},
SslMode::Require => {
let builder = SslConnector::builder(SslMethod::tls())?;
MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
}
};

let (client, connection) = tokio_postgres::connect(&database_url, connector).await?;

tokio::spawn(async move {
if let Err(e) = connection.await {
Expand Down
Loading