diff --git a/Cargo.lock b/Cargo.lock index 93596a423bf81..e6230a74278f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6858,9 +6858,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.60" +version = "0.10.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79a4c6c3a2b158f7f8f2a2fc5a969fa3a068df6fc9dbb4a43845436e3af7c800" +checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" dependencies = [ "bitflags 2.4.2", "cfg-if", @@ -6899,9 +6899,9 @@ dependencies = [ [[package]] name = "openssl-sys" -version = "0.9.96" +version = "0.9.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3812c071ba60da8b5677cc12bcb1d42989a65553772897a7e0355545a819838f" +checksum = "dda2b0f344e78efc2facf7d195d098df0dd72151b26ab98da807afc26c198dff" dependencies = [ "cc", "libc", @@ -7636,6 +7636,19 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "postgres-openssl" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1de0ea6504e07ca78355a6fb88ad0f36cafe9e696cbc6717f16a207f3a60be72" +dependencies = [ + "futures", + "openssl", + "tokio", + "tokio-openssl", + "tokio-postgres", +] + [[package]] name = "postgres-protocol" version = "0.6.6" @@ -9304,8 +9317,10 @@ dependencies = [ "nexmark", "num-bigint", "opendal", + "openssl", "parking_lot 0.12.1", "paste", + "postgres-openssl", "pretty_assertions", "prometheus", "prost 0.12.1", diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 043e542e4a5c4..a34cdf04bc607 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -83,8 +83,10 @@ mysql_common = { version = "0.32", default-features = false, features = [ nexmark = { version = "0.2", features = ["serde"] } num-bigint = "0.4" opendal = "0.45" +openssl = "0.10" parking_lot = "0.12" paste = "1" +postgres-openssl = "0.5.0" prometheus = { version = "0.13", features = ["process"] } prost = { version = "0.12", features = ["no-recursion-limit"] } prost-reflect = "0.13" diff --git a/src/connector/src/error.rs b/src/connector/src/error.rs index 8e4edda38249a..17885931f58e6 100644 --- a/src/connector/src/error.rs +++ b/src/connector/src/error.rs @@ -61,6 +61,8 @@ def_anyhow_newtype! { tokio_rustls::rustls::Error => "TLS error", rumqttc::v5::ClientError => "MQTT error", rumqttc::v5::OptionError => "MQTT error", + + openssl::error::ErrorStack => "OpenSSL error", } pub type ConnectorResult = std::result::Result; diff --git a/src/connector/src/source/cdc/external/maybe_tls_connector.rs b/src/connector/src/source/cdc/external/maybe_tls_connector.rs new file mode 100644 index 0000000000000..add99716e20c6 --- /dev/null +++ b/src/connector/src/source/cdc/external/maybe_tls_connector.rs @@ -0,0 +1,182 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use core::task; +use std::error::Error; +use std::io; +use std::pin::Pin; +use std::task::Poll; + +use futures::{Future, FutureExt}; +use openssl::error::ErrorStack; +use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect}; +use tokio_postgres::NoTls; + +pub enum MaybeMakeTlsConnector { + NoTls(NoTls), + Tls(MakeTlsConnector), +} + +impl MakeTlsConnect for MaybeMakeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send, +{ + type Error = ErrorStack; + type Stream = MaybeTlsStream; + type TlsConnect = MaybeTlsConnector; + + fn make_tls_connect(&mut self, domain: &str) -> Result { + match self { + MaybeMakeTlsConnector::NoTls(make_connector) => { + let connector = + >::make_tls_connect(make_connector, domain) + .expect("make NoTls connector always success"); + Ok(MaybeTlsConnector::NoTls(connector)) + } + MaybeMakeTlsConnector::Tls(make_connector) => { + >::make_tls_connect(make_connector, domain) + .map(MaybeTlsConnector::Tls) + } + } + } +} + +pub enum MaybeTlsConnector { + NoTls(NoTls), + Tls(TlsConnector), +} + +impl TlsConnect for MaybeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, +{ + type Error = Box; + type Future = MaybeTlsFuture; + type Stream = MaybeTlsStream; + + fn connect(self, stream: S) -> Self::Future { + match self { + MaybeTlsConnector::NoTls(connector) => MaybeTlsFuture::NoTls(connector.connect(stream)), + MaybeTlsConnector::Tls(connector) => MaybeTlsFuture::Tls(Box::pin( + connector + .connect(stream) + .map(|x| x.map(|x| MaybeTlsStream::Tls(x))), + )), + } + } +} + +pub enum MaybeTlsStream { + NoTls(NoTlsStream), + Tls(TlsStream), +} + +impl AsyncRead for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_read(Pin::new(stream), cx, buf) + } + MaybeTlsStream::Tls(stream) => { + as AsyncRead>::poll_read(Pin::new(stream), cx, buf) + } + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_write(Pin::new(stream), cx, buf) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_write(Pin::new(stream), cx, buf) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_flush(Pin::new(stream), cx) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_flush(Pin::new(stream), cx) + } + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_shutdown(Pin::new(stream), cx) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_shutdown(Pin::new(stream), cx) + } + } + } +} + +impl tls::TlsStream for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn channel_binding(&self) -> tls::ChannelBinding { + match self { + MaybeTlsStream::NoTls(stream) => stream.channel_binding(), + MaybeTlsStream::Tls(stream) => stream.channel_binding(), + } + } +} + +pub enum MaybeTlsFuture { + NoTls(NoTlsFuture), + Tls(Pin> + Send>>), +} + +impl Future for MaybeTlsFuture, E> +where + MaybeTlsStream: Sync + Send, + E: std::convert::From, +{ + type Output = Result, E>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + match &mut *self { + MaybeTlsFuture::NoTls(fut) => fut + .poll_unpin(cx) + .map(|x| x.map(|x| MaybeTlsStream::NoTls(x))) + .map_err(|x| x.into()), + MaybeTlsFuture::Tls(fut) => fut.poll_unpin(cx), + } + } +} diff --git a/src/connector/src/source/cdc/external/mod.rs b/src/connector/src/source/cdc/external/mod.rs index 7de085fceb7ce..ed549fd2e5e28 100644 --- a/src/connector/src/source/cdc/external/mod.rs +++ b/src/connector/src/source/cdc/external/mod.rs @@ -15,7 +15,11 @@ pub mod mock_external_table; mod postgres; +#[cfg(not(madsim))] +mod maybe_tls_connector; + use std::collections::HashMap; +use std::fmt; use anyhow::Context; use futures::stream::BoxStream; @@ -245,6 +249,35 @@ pub struct ExternalTableConfig { pub schema: String, #[serde(rename = "table.name")] pub table: String, + /// `ssl.mode` specifies the SSL/TLS encryption level for secure communication with Postgres. + /// Choices include `disable`, `prefer`, and `require`. + /// This field is optional. `prefer` is used if not specified. + #[serde(rename = "ssl.mode", default = "Default::default")] + pub sslmode: SslMode, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SslMode { + Disable, + Prefer, + Require, +} + +impl Default for SslMode { + fn default() -> Self { + Self::Prefer + } +} + +impl fmt::Display for SslMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + SslMode::Disable => "disable", + SslMode::Prefer => "prefer", + SslMode::Require => "require", + }) + } } impl ExternalTableReader for MySqlExternalTableReader { diff --git a/src/connector/src/source/cdc/external/postgres.rs b/src/connector/src/source/cdc/external/postgres.rs index 0880faa902a26..7660546af14cb 100644 --- a/src/connector/src/source/cdc/external/postgres.rs +++ b/src/connector/src/source/cdc/external/postgres.rs @@ -20,6 +20,8 @@ use futures::stream::BoxStream; use futures::{pin_mut, StreamExt}; use futures_async_stream::try_stream; use itertools::Itertools; +use openssl::ssl::{SslConnector, SslMethod}; +use postgres_openssl::MakeTlsConnector; use risingwave_common::catalog::Schema; use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::DatumRef; @@ -30,9 +32,11 @@ use tokio_postgres::NoTls; use crate::error::{ConnectorError, ConnectorResult}; use crate::parser::postgres_row_to_owned_row; +#[cfg(not(madsim))] +use crate::source::cdc::external::maybe_tls_connector::MaybeMakeTlsConnector; use crate::source::cdc::external::{ CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader, - SchemaTableName, + SchemaTableName, SslMode, }; #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] @@ -129,11 +133,34 @@ impl PostgresExternalTableReader { .context("failed to extract postgres connector properties")?; let database_url = format!( - "postgresql://{}:{}@{}:{}/{}", - config.username, config.password, config.host, config.port, config.database + "postgresql://{}:{}@{}:{}/{}?sslmode={}", + config.username, + config.password, + config.host, + config.port, + config.database, + config.sslmode ); - let (client, connection) = tokio_postgres::connect(&database_url, NoTls).await?; + #[cfg(not(madsim))] + let connector = match config.sslmode { + SslMode::Disable => MaybeMakeTlsConnector::NoTls(NoTls), + SslMode::Prefer => match SslConnector::builder(SslMethod::tls()) { + Ok(builder) => MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build())), + Err(e) => { + tracing::warn!(error = %e.as_report(), "SSL connector error"); + MaybeMakeTlsConnector::NoTls(NoTls) + } + }, + SslMode::Require => { + let builder = SslConnector::builder(SslMethod::tls())?; + MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build())) + } + }; + #[cfg(madsim)] + let connector = NoTls; + + let (client, connection) = tokio_postgres::connect(&database_url, connector).await?; tokio::spawn(async move { if let Err(e) = connection.await {