From eb60978b508eb82407b1d2acbc9b13614bad96b0 Mon Sep 17 00:00:00 2001 From: Jinser Kafka Date: Tue, 12 Mar 2024 17:04:24 +0800 Subject: [PATCH 1/7] feat: add sslmode to external table config --- src/connector/src/source/cdc/external/mod.rs | 27 +++++++++++++++++++ .../src/source/cdc/external/postgres.rs | 9 +++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/connector/src/source/cdc/external/mod.rs b/src/connector/src/source/cdc/external/mod.rs index 7de085fceb7ce..1f182967f7cf2 100644 --- a/src/connector/src/source/cdc/external/mod.rs +++ b/src/connector/src/source/cdc/external/mod.rs @@ -16,6 +16,7 @@ pub mod mock_external_table; mod postgres; use std::collections::HashMap; +use std::fmt; use anyhow::Context; use futures::stream::BoxStream; @@ -245,6 +246,32 @@ pub struct ExternalTableConfig { pub schema: String, #[serde(rename = "table.name")] pub table: String, + #[serde(rename = "ssl.name", default = "Default::default")] + pub sslmode: SslMode, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SslMode { + Disable, + Prefer, + Require, +} + +impl Default for SslMode { + fn default() -> Self { + Self::Disable + } +} + +impl fmt::Display for SslMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + SslMode::Disable => "disable", + SslMode::Prefer => "prefer", + SslMode::Require => "require", + }) + } } impl ExternalTableReader for MySqlExternalTableReader { diff --git a/src/connector/src/source/cdc/external/postgres.rs b/src/connector/src/source/cdc/external/postgres.rs index 0880faa902a26..51e45a057e561 100644 --- a/src/connector/src/source/cdc/external/postgres.rs +++ b/src/connector/src/source/cdc/external/postgres.rs @@ -129,8 +129,13 @@ impl PostgresExternalTableReader { .context("failed to extract postgres connector properties")?; let database_url = format!( - "postgresql://{}:{}@{}:{}/{}", - config.username, config.password, config.host, config.port, config.database + "postgresql://{}:{}@{}:{}/{}?sslmode={}", + config.username, + config.password, + config.host, + config.port, + config.database, + dbg!(&config.sslmode) ); let (client, connection) = tokio_postgres::connect(&database_url, NoTls).await?; From 425a11954fd625674c845ccf2697509fec2c98c8 Mon Sep 17 00:00:00 2001 From: Jinser Kafka Date: Thu, 14 Mar 2024 16:24:40 +0800 Subject: [PATCH 2/7] feat: add support for SSL mode in postgres connector --- Cargo.lock | 25 +++++++++++++++---- src/connector/Cargo.toml | 2 ++ src/connector/src/error.rs | 2 ++ src/connector/src/source/cdc/external/mod.rs | 2 +- .../src/source/cdc/external/postgres.rs | 10 +++++--- 5 files changed, 32 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4584f80a54ad4..4392fe9e5f289 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6897,9 +6897,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.60" +version = "0.10.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79a4c6c3a2b158f7f8f2a2fc5a969fa3a068df6fc9dbb4a43845436e3af7c800" +checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" dependencies = [ "bitflags 2.4.2", "cfg-if", @@ -6938,9 +6938,9 @@ dependencies = [ [[package]] name = "openssl-sys" -version = "0.9.96" +version = "0.9.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3812c071ba60da8b5677cc12bcb1d42989a65553772897a7e0355545a819838f" +checksum = "dda2b0f344e78efc2facf7d195d098df0dd72151b26ab98da807afc26c198dff" dependencies = [ "cc", "libc", @@ -7689,6 +7689,19 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "postgres-openssl" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1de0ea6504e07ca78355a6fb88ad0f36cafe9e696cbc6717f16a207f3a60be72" +dependencies = [ + "futures", + "openssl", + "tokio", + "tokio-openssl", + "tokio-postgres", +] + [[package]] name = "postgres-protocol" version = "0.6.6" @@ -9330,9 +9343,11 @@ dependencies = [ "mysql_common", "nexmark", "num-bigint", - "opendal 0.44.2", + "opendal", + "openssl", "parking_lot 0.12.1", "paste", + "postgres-openssl", "pretty_assertions", "prometheus", "prost 0.12.1", diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index dc5d71807260d..555fd2bf4bea8 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -83,8 +83,10 @@ mysql_common = { version = "0.31", default-features = false, features = [ nexmark = { version = "0.2", features = ["serde"] } num-bigint = "0.4" opendal = "0.44" +openssl = "0.10" parking_lot = "0.12" paste = "1" +postgres-openssl = "0.5.0" prometheus = { version = "0.13", features = ["process"] } prost = { version = "0.12", features = ["no-recursion-limit"] } prost-reflect = "0.13" diff --git a/src/connector/src/error.rs b/src/connector/src/error.rs index 8e4edda38249a..17885931f58e6 100644 --- a/src/connector/src/error.rs +++ b/src/connector/src/error.rs @@ -61,6 +61,8 @@ def_anyhow_newtype! { tokio_rustls::rustls::Error => "TLS error", rumqttc::v5::ClientError => "MQTT error", rumqttc::v5::OptionError => "MQTT error", + + openssl::error::ErrorStack => "OpenSSL error", } pub type ConnectorResult = std::result::Result; diff --git a/src/connector/src/source/cdc/external/mod.rs b/src/connector/src/source/cdc/external/mod.rs index 1f182967f7cf2..fc5226a3bf183 100644 --- a/src/connector/src/source/cdc/external/mod.rs +++ b/src/connector/src/source/cdc/external/mod.rs @@ -246,7 +246,7 @@ pub struct ExternalTableConfig { pub schema: String, #[serde(rename = "table.name")] pub table: String, - #[serde(rename = "ssl.name", default = "Default::default")] + #[serde(rename = "ssl.mode", default = "Default::default")] pub sslmode: SslMode, } diff --git a/src/connector/src/source/cdc/external/postgres.rs b/src/connector/src/source/cdc/external/postgres.rs index 51e45a057e561..ad4af35977b46 100644 --- a/src/connector/src/source/cdc/external/postgres.rs +++ b/src/connector/src/source/cdc/external/postgres.rs @@ -20,13 +20,14 @@ use futures::stream::BoxStream; use futures::{pin_mut, StreamExt}; use futures_async_stream::try_stream; use itertools::Itertools; +use openssl::ssl::{SslConnector, SslMethod}; +use postgres_openssl::MakeTlsConnector; use risingwave_common::catalog::Schema; use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::DatumRef; use serde_derive::{Deserialize, Serialize}; use thiserror_ext::AsReport; use tokio_postgres::types::PgLsn; -use tokio_postgres::NoTls; use crate::error::{ConnectorError, ConnectorResult}; use crate::parser::postgres_row_to_owned_row; @@ -135,10 +136,13 @@ impl PostgresExternalTableReader { config.host, config.port, config.database, - dbg!(&config.sslmode) + config.sslmode ); - let (client, connection) = tokio_postgres::connect(&database_url, NoTls).await?; + let builder = SslConnector::builder(SslMethod::tls())?; + let connector = MakeTlsConnector::new(builder.build()); + + let (client, connection) = tokio_postgres::connect(&database_url, connector).await?; tokio::spawn(async move { if let Err(e) = connection.await { From 9cf2811fab2759a29cf4445442e8a830b454bbb2 Mon Sep 17 00:00:00 2001 From: Jinser Kafka Date: Thu, 14 Mar 2024 16:51:13 +0800 Subject: [PATCH 3/7] feat: default ssl.mode to `prefer` --- src/connector/src/source/cdc/external/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/connector/src/source/cdc/external/mod.rs b/src/connector/src/source/cdc/external/mod.rs index fc5226a3bf183..40b9f97bb5c2a 100644 --- a/src/connector/src/source/cdc/external/mod.rs +++ b/src/connector/src/source/cdc/external/mod.rs @@ -260,7 +260,7 @@ pub enum SslMode { impl Default for SslMode { fn default() -> Self { - Self::Disable + Self::Prefer } } From 46a9074afbe08bf241f8fed1710256b24bc7fb78 Mon Sep 17 00:00:00 2001 From: Jinser Kafka Date: Tue, 19 Mar 2024 19:32:36 +0800 Subject: [PATCH 4/7] fix: add MaybeMakeTlsConnector to support no tls only when sslmode is disable, we should use NoTls, and we should use Tls when sslmode is prefer or require. --- Cargo.lock | 2 +- .../src/source/cdc/external/postgres.rs | 188 +++++++++++++++++- 2 files changed, 184 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4392fe9e5f289..cecf0201a5af9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9343,7 +9343,7 @@ dependencies = [ "mysql_common", "nexmark", "num-bigint", - "opendal", + "opendal 0.44.2", "openssl", "parking_lot 0.12.1", "paste", diff --git a/src/connector/src/source/cdc/external/postgres.rs b/src/connector/src/source/cdc/external/postgres.rs index ad4af35977b46..788af27a7e5b1 100644 --- a/src/connector/src/source/cdc/external/postgres.rs +++ b/src/connector/src/source/cdc/external/postgres.rs @@ -12,30 +12,196 @@ // See the License for the specific language governing permissions and // limitations under the License. +use core::task; use std::cmp::Ordering; use std::collections::HashMap; +use std::error::Error; +use std::io; +use std::pin::Pin; +use std::task::Poll; use anyhow::Context; use futures::stream::BoxStream; -use futures::{pin_mut, StreamExt}; +use futures::{pin_mut, Future, FutureExt, StreamExt}; use futures_async_stream::try_stream; use itertools::Itertools; +use openssl::error::ErrorStack; use openssl::ssl::{SslConnector, SslMethod}; -use postgres_openssl::MakeTlsConnector; +use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream}; use risingwave_common::catalog::Schema; use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::DatumRef; use serde_derive::{Deserialize, Serialize}; use thiserror_ext::AsReport; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect}; use tokio_postgres::types::PgLsn; +use tokio_postgres::NoTls; use crate::error::{ConnectorError, ConnectorResult}; use crate::parser::postgres_row_to_owned_row; use crate::source::cdc::external::{ CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader, - SchemaTableName, + SchemaTableName, SslMode, }; +enum MaybeMakeTlsConnector { + NoTls(NoTls), + Tls(MakeTlsConnector), +} + +impl MakeTlsConnect for MaybeMakeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send, +{ + type Error = ErrorStack; + type Stream = MaybeTlsStream; + type TlsConnect = MaybeTlsConnector; + + fn make_tls_connect(&mut self, domain: &str) -> Result { + match self { + MaybeMakeTlsConnector::NoTls(make_connector) => { + let connector = + >::make_tls_connect(make_connector, domain) + .expect("make NoTls connector always success"); + Ok(MaybeTlsConnector::NoTls(connector)) + } + MaybeMakeTlsConnector::Tls(make_connector) => { + >::make_tls_connect(make_connector, domain) + .map(MaybeTlsConnector::Tls) + } + } + } +} + +enum MaybeTlsConnector { + NoTls(NoTls), + Tls(TlsConnector), +} + +impl TlsConnect for MaybeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Error = Box; + type Future = MaybeTlsFuture, Self::Error>; + type Stream = MaybeTlsStream; + + fn connect(self, stream: S) -> Self::Future { + match self { + MaybeTlsConnector::NoTls(connector) => MaybeTlsFuture::NoTls(connector.connect(stream)), + MaybeTlsConnector::Tls(connector) => MaybeTlsFuture::Tls(Box::pin( + connector + .connect(stream) + .map(|x| x.map(|x| MaybeTlsStream::Tls(x))), + )), + } + } +} + +enum MaybeTlsStream { + NoTls(NoTlsStream), + Tls(TlsStream), +} + +impl AsyncRead for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_read(Pin::new(stream), cx, buf) + } + MaybeTlsStream::Tls(stream) => { + as AsyncRead>::poll_read(Pin::new(stream), cx, buf) + } + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_write(Pin::new(stream), cx, buf) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_write(Pin::new(stream), cx, buf) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_flush(Pin::new(stream), cx) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_flush(Pin::new(stream), cx) + } + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_shutdown(Pin::new(stream), cx) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_shutdown(Pin::new(stream), cx) + } + } + } +} + +impl tls::TlsStream for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn channel_binding(&self) -> tls::ChannelBinding { + match self { + MaybeTlsStream::NoTls(stream) => stream.channel_binding(), + MaybeTlsStream::Tls(stream) => stream.channel_binding(), + } + } +} + +enum MaybeTlsFuture { + NoTls(NoTlsFuture), + Tls(Pin> + Send>>), +} + +impl Future for MaybeTlsFuture, E> +where + E: std::convert::From, +{ + type Output = Result, E>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + match &mut *self { + MaybeTlsFuture::NoTls(fut) => fut + .poll_unpin(cx) + .map(|x| x.map(|x| MaybeTlsStream::NoTls(x))) + .map_err(|x| x.into()), + MaybeTlsFuture::Tls(fut) => fut.poll_unpin(cx), + } + } +} + +// ======= + #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] pub struct PostgresOffset { pub txid: i64, @@ -139,8 +305,20 @@ impl PostgresExternalTableReader { config.sslmode ); - let builder = SslConnector::builder(SslMethod::tls())?; - let connector = MakeTlsConnector::new(builder.build()); + let connector = match config.sslmode { + SslMode::Disable => MaybeMakeTlsConnector::NoTls(NoTls), + SslMode::Prefer => match SslConnector::builder(SslMethod::tls()) { + Ok(builder) => MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build())), + Err(e) => { + tracing::warn!(error = %e.as_report(), "SSL connector error"); + MaybeMakeTlsConnector::NoTls(NoTls) + } + }, + SslMode::Require => { + let builder = SslConnector::builder(SslMethod::tls())?; + MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build())) + } + }; let (client, connection) = tokio_postgres::connect(&database_url, connector).await?; From 5cfb695abf461e558ccb9aaa8e72d2ace450af9c Mon Sep 17 00:00:00 2001 From: Jinser Kafka Date: Wed, 20 Mar 2024 09:48:46 +0800 Subject: [PATCH 5/7] chore: add comment to `ssl.mode` field in ExternalTableConfig --- src/connector/src/source/cdc/external/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/connector/src/source/cdc/external/mod.rs b/src/connector/src/source/cdc/external/mod.rs index 40b9f97bb5c2a..d6a6c9cf96aec 100644 --- a/src/connector/src/source/cdc/external/mod.rs +++ b/src/connector/src/source/cdc/external/mod.rs @@ -246,6 +246,9 @@ pub struct ExternalTableConfig { pub schema: String, #[serde(rename = "table.name")] pub table: String, + /// `ssl.mode` specifies the SSL/TLS encryption level for secure communication with Postgres. + /// Choices include `disable`, `prefer`, and `require`. + /// This field is optional. `prefer` is used if not specified. #[serde(rename = "ssl.mode", default = "Default::default")] pub sslmode: SslMode, } From 749e38c4160745680636aadba9e9989063e34531 Mon Sep 17 00:00:00 2001 From: Jinser Kafka Date: Fri, 22 Mar 2024 15:10:53 +0800 Subject: [PATCH 6/7] fix: add Sync to TlsConnect trait bound --- src/connector/src/source/cdc/external/postgres.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/connector/src/source/cdc/external/postgres.rs b/src/connector/src/source/cdc/external/postgres.rs index 788af27a7e5b1..e5b550a24fe9e 100644 --- a/src/connector/src/source/cdc/external/postgres.rs +++ b/src/connector/src/source/cdc/external/postgres.rs @@ -81,10 +81,10 @@ enum MaybeTlsConnector { impl TlsConnect for MaybeTlsConnector where - S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, { type Error = Box; - type Future = MaybeTlsFuture, Self::Error>; + type Future = MaybeTlsFuture; type Stream = MaybeTlsStream; fn connect(self, stream: S) -> Self::Future { @@ -185,6 +185,7 @@ enum MaybeTlsFuture { impl Future for MaybeTlsFuture, E> where + MaybeTlsStream: Sync + Send, E: std::convert::From, { type Output = Result, E>; From 53a02901915385b6433267392906685083fd24ed Mon Sep 17 00:00:00 2001 From: Jinser Kafka Date: Wed, 27 Mar 2024 00:49:49 +0800 Subject: [PATCH 7/7] fix: only enable MaybeTlsConnector when not madsim --- .../cdc/external/maybe_tls_connector.rs | 182 ++++++++++++++++++ src/connector/src/source/cdc/external/mod.rs | 3 + .../src/source/cdc/external/postgres.rs | 175 +---------------- 3 files changed, 192 insertions(+), 168 deletions(-) create mode 100644 src/connector/src/source/cdc/external/maybe_tls_connector.rs diff --git a/src/connector/src/source/cdc/external/maybe_tls_connector.rs b/src/connector/src/source/cdc/external/maybe_tls_connector.rs new file mode 100644 index 0000000000000..add99716e20c6 --- /dev/null +++ b/src/connector/src/source/cdc/external/maybe_tls_connector.rs @@ -0,0 +1,182 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use core::task; +use std::error::Error; +use std::io; +use std::pin::Pin; +use std::task::Poll; + +use futures::{Future, FutureExt}; +use openssl::error::ErrorStack; +use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect}; +use tokio_postgres::NoTls; + +pub enum MaybeMakeTlsConnector { + NoTls(NoTls), + Tls(MakeTlsConnector), +} + +impl MakeTlsConnect for MaybeMakeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send, +{ + type Error = ErrorStack; + type Stream = MaybeTlsStream; + type TlsConnect = MaybeTlsConnector; + + fn make_tls_connect(&mut self, domain: &str) -> Result { + match self { + MaybeMakeTlsConnector::NoTls(make_connector) => { + let connector = + >::make_tls_connect(make_connector, domain) + .expect("make NoTls connector always success"); + Ok(MaybeTlsConnector::NoTls(connector)) + } + MaybeMakeTlsConnector::Tls(make_connector) => { + >::make_tls_connect(make_connector, domain) + .map(MaybeTlsConnector::Tls) + } + } + } +} + +pub enum MaybeTlsConnector { + NoTls(NoTls), + Tls(TlsConnector), +} + +impl TlsConnect for MaybeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, +{ + type Error = Box; + type Future = MaybeTlsFuture; + type Stream = MaybeTlsStream; + + fn connect(self, stream: S) -> Self::Future { + match self { + MaybeTlsConnector::NoTls(connector) => MaybeTlsFuture::NoTls(connector.connect(stream)), + MaybeTlsConnector::Tls(connector) => MaybeTlsFuture::Tls(Box::pin( + connector + .connect(stream) + .map(|x| x.map(|x| MaybeTlsStream::Tls(x))), + )), + } + } +} + +pub enum MaybeTlsStream { + NoTls(NoTlsStream), + Tls(TlsStream), +} + +impl AsyncRead for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_read(Pin::new(stream), cx, buf) + } + MaybeTlsStream::Tls(stream) => { + as AsyncRead>::poll_read(Pin::new(stream), cx, buf) + } + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_write(Pin::new(stream), cx, buf) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_write(Pin::new(stream), cx, buf) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_flush(Pin::new(stream), cx) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_flush(Pin::new(stream), cx) + } + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::NoTls(stream) => { + ::poll_shutdown(Pin::new(stream), cx) + } + MaybeTlsStream::Tls(stream) => { + as AsyncWrite>::poll_shutdown(Pin::new(stream), cx) + } + } + } +} + +impl tls::TlsStream for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn channel_binding(&self) -> tls::ChannelBinding { + match self { + MaybeTlsStream::NoTls(stream) => stream.channel_binding(), + MaybeTlsStream::Tls(stream) => stream.channel_binding(), + } + } +} + +pub enum MaybeTlsFuture { + NoTls(NoTlsFuture), + Tls(Pin> + Send>>), +} + +impl Future for MaybeTlsFuture, E> +where + MaybeTlsStream: Sync + Send, + E: std::convert::From, +{ + type Output = Result, E>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + match &mut *self { + MaybeTlsFuture::NoTls(fut) => fut + .poll_unpin(cx) + .map(|x| x.map(|x| MaybeTlsStream::NoTls(x))) + .map_err(|x| x.into()), + MaybeTlsFuture::Tls(fut) => fut.poll_unpin(cx), + } + } +} diff --git a/src/connector/src/source/cdc/external/mod.rs b/src/connector/src/source/cdc/external/mod.rs index d6a6c9cf96aec..ed549fd2e5e28 100644 --- a/src/connector/src/source/cdc/external/mod.rs +++ b/src/connector/src/source/cdc/external/mod.rs @@ -15,6 +15,9 @@ pub mod mock_external_table; mod postgres; +#[cfg(not(madsim))] +mod maybe_tls_connector; + use std::collections::HashMap; use std::fmt; diff --git a/src/connector/src/source/cdc/external/postgres.rs b/src/connector/src/source/cdc/external/postgres.rs index e5b550a24fe9e..7660546af14cb 100644 --- a/src/connector/src/source/cdc/external/postgres.rs +++ b/src/connector/src/source/cdc/external/postgres.rs @@ -12,197 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. -use core::task; use std::cmp::Ordering; use std::collections::HashMap; -use std::error::Error; -use std::io; -use std::pin::Pin; -use std::task::Poll; use anyhow::Context; use futures::stream::BoxStream; -use futures::{pin_mut, Future, FutureExt, StreamExt}; +use futures::{pin_mut, StreamExt}; use futures_async_stream::try_stream; use itertools::Itertools; -use openssl::error::ErrorStack; use openssl::ssl::{SslConnector, SslMethod}; -use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream}; +use postgres_openssl::MakeTlsConnector; use risingwave_common::catalog::Schema; use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::DatumRef; use serde_derive::{Deserialize, Serialize}; use thiserror_ext::AsReport; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect}; use tokio_postgres::types::PgLsn; use tokio_postgres::NoTls; use crate::error::{ConnectorError, ConnectorResult}; use crate::parser::postgres_row_to_owned_row; +#[cfg(not(madsim))] +use crate::source::cdc::external::maybe_tls_connector::MaybeMakeTlsConnector; use crate::source::cdc::external::{ CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader, SchemaTableName, SslMode, }; -enum MaybeMakeTlsConnector { - NoTls(NoTls), - Tls(MakeTlsConnector), -} - -impl MakeTlsConnect for MaybeMakeTlsConnector -where - S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send, -{ - type Error = ErrorStack; - type Stream = MaybeTlsStream; - type TlsConnect = MaybeTlsConnector; - - fn make_tls_connect(&mut self, domain: &str) -> Result { - match self { - MaybeMakeTlsConnector::NoTls(make_connector) => { - let connector = - >::make_tls_connect(make_connector, domain) - .expect("make NoTls connector always success"); - Ok(MaybeTlsConnector::NoTls(connector)) - } - MaybeMakeTlsConnector::Tls(make_connector) => { - >::make_tls_connect(make_connector, domain) - .map(MaybeTlsConnector::Tls) - } - } - } -} - -enum MaybeTlsConnector { - NoTls(NoTls), - Tls(TlsConnector), -} - -impl TlsConnect for MaybeTlsConnector -where - S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, -{ - type Error = Box; - type Future = MaybeTlsFuture; - type Stream = MaybeTlsStream; - - fn connect(self, stream: S) -> Self::Future { - match self { - MaybeTlsConnector::NoTls(connector) => MaybeTlsFuture::NoTls(connector.connect(stream)), - MaybeTlsConnector::Tls(connector) => MaybeTlsFuture::Tls(Box::pin( - connector - .connect(stream) - .map(|x| x.map(|x| MaybeTlsStream::Tls(x))), - )), - } - } -} - -enum MaybeTlsStream { - NoTls(NoTlsStream), - Tls(TlsStream), -} - -impl AsyncRead for MaybeTlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match &mut *self { - MaybeTlsStream::NoTls(stream) => { - ::poll_read(Pin::new(stream), cx, buf) - } - MaybeTlsStream::Tls(stream) => { - as AsyncRead>::poll_read(Pin::new(stream), cx, buf) - } - } - } -} - -impl AsyncWrite for MaybeTlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - buf: &[u8], - ) -> Poll> { - match &mut *self { - MaybeTlsStream::NoTls(stream) => { - ::poll_write(Pin::new(stream), cx, buf) - } - MaybeTlsStream::Tls(stream) => { - as AsyncWrite>::poll_write(Pin::new(stream), cx, buf) - } - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - match &mut *self { - MaybeTlsStream::NoTls(stream) => { - ::poll_flush(Pin::new(stream), cx) - } - MaybeTlsStream::Tls(stream) => { - as AsyncWrite>::poll_flush(Pin::new(stream), cx) - } - } - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - match &mut *self { - MaybeTlsStream::NoTls(stream) => { - ::poll_shutdown(Pin::new(stream), cx) - } - MaybeTlsStream::Tls(stream) => { - as AsyncWrite>::poll_shutdown(Pin::new(stream), cx) - } - } - } -} - -impl tls::TlsStream for MaybeTlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn channel_binding(&self) -> tls::ChannelBinding { - match self { - MaybeTlsStream::NoTls(stream) => stream.channel_binding(), - MaybeTlsStream::Tls(stream) => stream.channel_binding(), - } - } -} - -enum MaybeTlsFuture { - NoTls(NoTlsFuture), - Tls(Pin> + Send>>), -} - -impl Future for MaybeTlsFuture, E> -where - MaybeTlsStream: Sync + Send, - E: std::convert::From, -{ - type Output = Result, E>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - match &mut *self { - MaybeTlsFuture::NoTls(fut) => fut - .poll_unpin(cx) - .map(|x| x.map(|x| MaybeTlsStream::NoTls(x))) - .map_err(|x| x.into()), - MaybeTlsFuture::Tls(fut) => fut.poll_unpin(cx), - } - } -} - -// ======= - #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] pub struct PostgresOffset { pub txid: i64, @@ -306,6 +142,7 @@ impl PostgresExternalTableReader { config.sslmode ); + #[cfg(not(madsim))] let connector = match config.sslmode { SslMode::Disable => MaybeMakeTlsConnector::NoTls(NoTls), SslMode::Prefer => match SslConnector::builder(SslMethod::tls()) { @@ -320,6 +157,8 @@ impl PostgresExternalTableReader { MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build())) } }; + #[cfg(madsim)] + let connector = NoTls; let (client, connection) = tokio_postgres::connect(&database_url, connector).await?;