diff --git a/src/connector/src/source/cdc/external/maybe_tls_connector.rs b/src/connector/src/source/cdc/external/maybe_tls_connector.rs new file mode 100644 index 0000000000000..add99716e20c6 --- /dev/null +++ b/src/connector/src/source/cdc/external/maybe_tls_connector.rs @@ -0,0 +1,182 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use core::task; +use std::error::Error; +use std::io; +use std::pin::Pin; +use std::task::Poll; + +use futures::{Future, FutureExt}; +use openssl::error::ErrorStack; +use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect}; +use tokio_postgres::NoTls; + +pub enum MaybeMakeTlsConnector { + NoTls(NoTls), + Tls(MakeTlsConnector), +} + +impl MakeTlsConnect for MaybeMakeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send, +{ + type Error = ErrorStack; + type Stream = MaybeTlsStream; + type TlsConnect = MaybeTlsConnector; + + fn make_tls_connect(&mut self, domain: &str) -> Result { + match self { + MaybeMakeTlsConnector::NoTls(make_connector) => { + let connector = + >::make_tls_connect(make_connector, domain) + .expect("make NoTls connector always success"); + Ok(MaybeTlsConnector::NoTls(connector)) + } + MaybeMakeTlsConnector::Tls(make_connector) => { + >::make_tls_connect(make_connector, domain) + .map(MaybeTlsConnector::Tls) + } + } + } +} + +pub enum MaybeTlsConnector { + NoTls(NoTls), + Tls(TlsConnector), +} + +impl TlsConnect for MaybeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, +{ + type Error = Box; + type Future = MaybeTlsFuture; + type Stream = MaybeTlsStream; + + fn connect(self, stream: S) -> Self::Future { + match self { + MaybeTlsConnector::NoTls(connector) => MaybeTlsFuture::NoTls(connector.connect(stream)), + MaybeTlsConnector::Tls(connector) => MaybeTlsFuture::Tls(Box::pin( + connector + .connect(stream) + .map(|x| x.map(|x| MaybeTlsStream::Tls(x))), + )), + } + } +} + +pub enum MaybeTlsStream { + NoTls(NoTlsStream), + Tls(TlsStream), +} + +impl AsyncRead for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_read(Pin::new(stream), cx, buf) + } + MaybeTlsStream::Tls(stream) => { + as AsyncRead>::poll_read(Pin::new(stream), cx, buf) + } + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_write(Pin::new(stream), cx, buf) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_write(Pin::new(stream), cx, buf) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_flush(Pin::new(stream), cx) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_flush(Pin::new(stream), cx) + } + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_shutdown(Pin::new(stream), cx) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_shutdown(Pin::new(stream), cx) + } + } + } +} + +impl tls::TlsStream for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn channel_binding(&self) -> tls::ChannelBinding { + match self { + MaybeTlsStream::NoTls(stream) => stream.channel_binding(), + MaybeTlsStream::Tls(stream) => stream.channel_binding(), + } + } +} + +pub enum MaybeTlsFuture { + NoTls(NoTlsFuture), + Tls(Pin> + Send>>), +} + +impl Future for MaybeTlsFuture, E> +where + MaybeTlsStream: Sync + Send, + E: std::convert::From, +{ + type Output = Result, E>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + match &mut *self { + MaybeTlsFuture::NoTls(fut) => fut + .poll_unpin(cx) + .map(|x| x.map(|x| MaybeTlsStream::NoTls(x))) + .map_err(|x| x.into()), + MaybeTlsFuture::Tls(fut) => fut.poll_unpin(cx), + } + } +} diff --git a/src/connector/src/source/cdc/external/mod.rs b/src/connector/src/source/cdc/external/mod.rs index d6a6c9cf96aec..ed549fd2e5e28 100644 --- a/src/connector/src/source/cdc/external/mod.rs +++ b/src/connector/src/source/cdc/external/mod.rs @@ -15,6 +15,9 @@ pub mod mock_external_table; mod postgres; +#[cfg(not(madsim))] +mod maybe_tls_connector; + use std::collections::HashMap; use std::fmt; diff --git a/src/connector/src/source/cdc/external/postgres.rs b/src/connector/src/source/cdc/external/postgres.rs index e5b550a24fe9e..7660546af14cb 100644 --- a/src/connector/src/source/cdc/external/postgres.rs +++ b/src/connector/src/source/cdc/external/postgres.rs @@ -12,197 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. -use core::task; use std::cmp::Ordering; use std::collections::HashMap; -use std::error::Error; -use std::io; -use std::pin::Pin; -use std::task::Poll; use anyhow::Context; use futures::stream::BoxStream; -use futures::{pin_mut, Future, FutureExt, StreamExt}; +use futures::{pin_mut, StreamExt}; use futures_async_stream::try_stream; use itertools::Itertools; -use openssl::error::ErrorStack; use openssl::ssl::{SslConnector, SslMethod}; -use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream}; +use postgres_openssl::MakeTlsConnector; use risingwave_common::catalog::Schema; use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::DatumRef; use serde_derive::{Deserialize, Serialize}; use thiserror_ext::AsReport; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect}; use tokio_postgres::types::PgLsn; use tokio_postgres::NoTls; use crate::error::{ConnectorError, ConnectorResult}; use crate::parser::postgres_row_to_owned_row; +#[cfg(not(madsim))] +use crate::source::cdc::external::maybe_tls_connector::MaybeMakeTlsConnector; use crate::source::cdc::external::{ CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader, SchemaTableName, SslMode, }; -enum MaybeMakeTlsConnector { - NoTls(NoTls), - Tls(MakeTlsConnector), -} - -impl MakeTlsConnect for MaybeMakeTlsConnector -where - S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send, -{ - type Error = ErrorStack; - type Stream = MaybeTlsStream; - type TlsConnect = MaybeTlsConnector; - - fn make_tls_connect(&mut self, domain: &str) -> Result { - match self { - MaybeMakeTlsConnector::NoTls(make_connector) => { - let connector = - >::make_tls_connect(make_connector, domain) - .expect("make NoTls connector always success"); - Ok(MaybeTlsConnector::NoTls(connector)) - } - MaybeMakeTlsConnector::Tls(make_connector) => { - >::make_tls_connect(make_connector, domain) - .map(MaybeTlsConnector::Tls) - } - } - } -} - -enum MaybeTlsConnector { - NoTls(NoTls), - Tls(TlsConnector), -} - -impl TlsConnect for MaybeTlsConnector -where - S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, -{ - type Error = Box; - type Future = MaybeTlsFuture; - type Stream = MaybeTlsStream; - - fn connect(self, stream: S) -> Self::Future { - match self { - MaybeTlsConnector::NoTls(connector) => MaybeTlsFuture::NoTls(connector.connect(stream)), - MaybeTlsConnector::Tls(connector) => MaybeTlsFuture::Tls(Box::pin( - connector - .connect(stream) - .map(|x| x.map(|x| MaybeTlsStream::Tls(x))), - )), - } - } -} - -enum MaybeTlsStream { - NoTls(NoTlsStream), - Tls(TlsStream), -} - -impl AsyncRead for MaybeTlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match &mut *self { - MaybeTlsStream::NoTls(stream) => { - ::poll_read(Pin::new(stream), cx, buf) - } - MaybeTlsStream::Tls(stream) => { - as AsyncRead>::poll_read(Pin::new(stream), cx, buf) - } - } - } -} - -impl AsyncWrite for MaybeTlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - buf: &[u8], - ) -> Poll> { - match &mut *self { - MaybeTlsStream::NoTls(stream) => { - ::poll_write(Pin::new(stream), cx, buf) - } - MaybeTlsStream::Tls(stream) => { - as AsyncWrite>::poll_write(Pin::new(stream), cx, buf) - } - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - match &mut *self { - MaybeTlsStream::NoTls(stream) => { - ::poll_flush(Pin::new(stream), cx) - } - MaybeTlsStream::Tls(stream) => { - as AsyncWrite>::poll_flush(Pin::new(stream), cx) - } - } - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - match &mut *self { - MaybeTlsStream::NoTls(stream) => { - ::poll_shutdown(Pin::new(stream), cx) - } - MaybeTlsStream::Tls(stream) => { - as AsyncWrite>::poll_shutdown(Pin::new(stream), cx) - } - } - } -} - -impl tls::TlsStream for MaybeTlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn channel_binding(&self) -> tls::ChannelBinding { - match self { - MaybeTlsStream::NoTls(stream) => stream.channel_binding(), - MaybeTlsStream::Tls(stream) => stream.channel_binding(), - } - } -} - -enum MaybeTlsFuture { - NoTls(NoTlsFuture), - Tls(Pin> + Send>>), -} - -impl Future for MaybeTlsFuture, E> -where - MaybeTlsStream: Sync + Send, - E: std::convert::From, -{ - type Output = Result, E>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - match &mut *self { - MaybeTlsFuture::NoTls(fut) => fut - .poll_unpin(cx) - .map(|x| x.map(|x| MaybeTlsStream::NoTls(x))) - .map_err(|x| x.into()), - MaybeTlsFuture::Tls(fut) => fut.poll_unpin(cx), - } - } -} - -// ======= - #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] pub struct PostgresOffset { pub txid: i64, @@ -306,6 +142,7 @@ impl PostgresExternalTableReader { config.sslmode ); + #[cfg(not(madsim))] let connector = match config.sslmode { SslMode::Disable => MaybeMakeTlsConnector::NoTls(NoTls), SslMode::Prefer => match SslConnector::builder(SslMethod::tls()) { @@ -320,6 +157,8 @@ impl PostgresExternalTableReader { MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build())) } }; + #[cfg(madsim)] + let connector = NoTls; let (client, connection) = tokio_postgres::connect(&database_url, connector).await?;