diff --git a/Cargo.lock b/Cargo.lock index 4392fe9e5f28..cecf0201a5af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9343,7 +9343,7 @@ dependencies = [ "mysql_common", "nexmark", "num-bigint", - "opendal", + "opendal 0.44.2", "openssl", "parking_lot 0.12.1", "paste", diff --git a/src/connector/src/source/cdc/external/postgres.rs b/src/connector/src/source/cdc/external/postgres.rs index ad4af35977b4..788af27a7e5b 100644 --- a/src/connector/src/source/cdc/external/postgres.rs +++ b/src/connector/src/source/cdc/external/postgres.rs @@ -12,30 +12,196 @@ // See the License for the specific language governing permissions and // limitations under the License. +use core::task; use std::cmp::Ordering; use std::collections::HashMap; +use std::error::Error; +use std::io; +use std::pin::Pin; +use std::task::Poll; use anyhow::Context; use futures::stream::BoxStream; -use futures::{pin_mut, StreamExt}; +use futures::{pin_mut, Future, FutureExt, StreamExt}; use futures_async_stream::try_stream; use itertools::Itertools; +use openssl::error::ErrorStack; use openssl::ssl::{SslConnector, SslMethod}; -use postgres_openssl::MakeTlsConnector; +use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream}; use risingwave_common::catalog::Schema; use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::DatumRef; use serde_derive::{Deserialize, Serialize}; use thiserror_ext::AsReport; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect}; use tokio_postgres::types::PgLsn; +use tokio_postgres::NoTls; use crate::error::{ConnectorError, ConnectorResult}; use crate::parser::postgres_row_to_owned_row; use crate::source::cdc::external::{ CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader, - SchemaTableName, + SchemaTableName, SslMode, }; +enum MaybeMakeTlsConnector { + NoTls(NoTls), + Tls(MakeTlsConnector), +} + +impl MakeTlsConnect for MaybeMakeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send, +{ + type Error = ErrorStack; + type Stream = MaybeTlsStream; + type TlsConnect = MaybeTlsConnector; + + fn make_tls_connect(&mut self, domain: &str) -> Result { + match self { + MaybeMakeTlsConnector::NoTls(make_connector) => { + let connector = + >::make_tls_connect(make_connector, domain) + .expect("make NoTls connector always success"); + Ok(MaybeTlsConnector::NoTls(connector)) + } + MaybeMakeTlsConnector::Tls(make_connector) => { + >::make_tls_connect(make_connector, domain) + .map(MaybeTlsConnector::Tls) + } + } + } +} + +enum MaybeTlsConnector { + NoTls(NoTls), + Tls(TlsConnector), +} + +impl TlsConnect for MaybeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Error = Box; + type Future = MaybeTlsFuture, Self::Error>; + type Stream = MaybeTlsStream; + + fn connect(self, stream: S) -> Self::Future { + match self { + MaybeTlsConnector::NoTls(connector) => MaybeTlsFuture::NoTls(connector.connect(stream)), + MaybeTlsConnector::Tls(connector) => MaybeTlsFuture::Tls(Box::pin( + connector + .connect(stream) + .map(|x| x.map(|x| MaybeTlsStream::Tls(x))), + )), + } + } +} + +enum MaybeTlsStream { + NoTls(NoTlsStream), + Tls(TlsStream), +} + +impl AsyncRead for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_read(Pin::new(stream), cx, buf) + } + MaybeTlsStream::Tls(stream) => { + as AsyncRead>::poll_read(Pin::new(stream), cx, buf) + } + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_write(Pin::new(stream), cx, buf) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_write(Pin::new(stream), cx, buf) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_flush(Pin::new(stream), cx) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_flush(Pin::new(stream), cx) + } + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_shutdown(Pin::new(stream), cx) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_shutdown(Pin::new(stream), cx) + } + } + } +} + +impl tls::TlsStream for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn channel_binding(&self) -> tls::ChannelBinding { + match self { + MaybeTlsStream::NoTls(stream) => stream.channel_binding(), + MaybeTlsStream::Tls(stream) => stream.channel_binding(), + } + } +} + +enum MaybeTlsFuture { + NoTls(NoTlsFuture), + Tls(Pin> + Send>>), +} + +impl Future for MaybeTlsFuture, E> +where + E: std::convert::From, +{ + type Output = Result, E>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + match &mut *self { + MaybeTlsFuture::NoTls(fut) => fut + .poll_unpin(cx) + .map(|x| x.map(|x| MaybeTlsStream::NoTls(x))) + .map_err(|x| x.into()), + MaybeTlsFuture::Tls(fut) => fut.poll_unpin(cx), + } + } +} + +// ======= + #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] pub struct PostgresOffset { pub txid: i64, @@ -139,8 +305,20 @@ impl PostgresExternalTableReader { config.sslmode ); - let builder = SslConnector::builder(SslMethod::tls())?; - let connector = MakeTlsConnector::new(builder.build()); + let connector = match config.sslmode { + SslMode::Disable => MaybeMakeTlsConnector::NoTls(NoTls), + SslMode::Prefer => match SslConnector::builder(SslMethod::tls()) { + Ok(builder) => MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build())), + Err(e) => { + tracing::warn!(error = %e.as_report(), "SSL connector error"); + MaybeMakeTlsConnector::NoTls(NoTls) + } + }, + SslMode::Require => { + let builder = SslConnector::builder(SslMethod::tls())?; + MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build())) + } + }; let (client, connection) = tokio_postgres::connect(&database_url, connector).await?;