Skip to content

Commit

Permalink
fix: add MaybeMakeTlsConnector to support no tls
Browse files Browse the repository at this point in the history
only when sslmode is disable, we should use NoTls, and we should use Tls
when sslmode is prefer or require.
  • Loading branch information
jetjinser committed Mar 19, 2024
1 parent 9cf2811 commit 46a9074
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

188 changes: 183 additions & 5 deletions src/connector/src/source/cdc/external/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,196 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use core::task;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::error::Error;
use std::io;
use std::pin::Pin;
use std::task::Poll;

use anyhow::Context;
use futures::stream::BoxStream;
use futures::{pin_mut, StreamExt};
use futures::{pin_mut, Future, FutureExt, StreamExt};
use futures_async_stream::try_stream;
use itertools::Itertools;
use openssl::error::ErrorStack;
use openssl::ssl::{SslConnector, SslMethod};
use postgres_openssl::MakeTlsConnector;
use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream};
use risingwave_common::catalog::Schema;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::DatumRef;
use serde_derive::{Deserialize, Serialize};
use thiserror_ext::AsReport;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect};
use tokio_postgres::types::PgLsn;
use tokio_postgres::NoTls;

use crate::error::{ConnectorError, ConnectorResult};
use crate::parser::postgres_row_to_owned_row;
use crate::source::cdc::external::{
CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader,
SchemaTableName,
SchemaTableName, SslMode,
};

enum MaybeMakeTlsConnector {
NoTls(NoTls),
Tls(MakeTlsConnector),
}

impl<S> MakeTlsConnect<S> for MaybeMakeTlsConnector
where
S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send,
{
type Error = ErrorStack;
type Stream = MaybeTlsStream<S>;
type TlsConnect = MaybeTlsConnector;

fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
match self {
MaybeMakeTlsConnector::NoTls(make_connector) => {
let connector =
<NoTls as MakeTlsConnect<S>>::make_tls_connect(make_connector, domain)
.expect("make NoTls connector always success");
Ok(MaybeTlsConnector::NoTls(connector))
}
MaybeMakeTlsConnector::Tls(make_connector) => {
<MakeTlsConnector as MakeTlsConnect<S>>::make_tls_connect(make_connector, domain)
.map(MaybeTlsConnector::Tls)
}
}
}
}

enum MaybeTlsConnector {
NoTls(NoTls),
Tls(TlsConnector),
}

impl<S> TlsConnect<S> for MaybeTlsConnector
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Error = Box<dyn Error + Send + Sync>;
type Future = MaybeTlsFuture<MaybeTlsStream<S>, Self::Error>;
type Stream = MaybeTlsStream<S>;

fn connect(self, stream: S) -> Self::Future {
match self {
MaybeTlsConnector::NoTls(connector) => MaybeTlsFuture::NoTls(connector.connect(stream)),
MaybeTlsConnector::Tls(connector) => MaybeTlsFuture::Tls(Box::pin(
connector
.connect(stream)
.map(|x| x.map(|x| MaybeTlsStream::Tls(x))),
)),
}
}
}

enum MaybeTlsStream<S> {
NoTls(NoTlsStream),
Tls(TlsStream<S>),
}

impl<S> AsyncRead for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncRead>::poll_read(Pin::new(stream), cx, buf)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncRead>::poll_read(Pin::new(stream), cx, buf)
}
}
}
}

impl<S> AsyncWrite for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncWrite>::poll_write(Pin::new(stream), cx, buf)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncWrite>::poll_write(Pin::new(stream), cx, buf)
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncWrite>::poll_flush(Pin::new(stream), cx)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncWrite>::poll_flush(Pin::new(stream), cx)
}
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::NoTls(stream) => {
<NoTlsStream as AsyncWrite>::poll_shutdown(Pin::new(stream), cx)
}
MaybeTlsStream::Tls(stream) => {
<TlsStream<S> as AsyncWrite>::poll_shutdown(Pin::new(stream), cx)
}
}
}
}

impl<S> tls::TlsStream for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn channel_binding(&self) -> tls::ChannelBinding {
match self {
MaybeTlsStream::NoTls(stream) => stream.channel_binding(),
MaybeTlsStream::Tls(stream) => stream.channel_binding(),
}
}
}

enum MaybeTlsFuture<S, E> {
NoTls(NoTlsFuture),
Tls(Pin<Box<dyn Future<Output = Result<S, E>> + Send>>),
}

impl<S, E> Future for MaybeTlsFuture<MaybeTlsStream<S>, E>
where
E: std::convert::From<tokio_postgres::tls::NoTlsError>,
{
type Output = Result<MaybeTlsStream<S>, E>;

fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
match &mut *self {
MaybeTlsFuture::NoTls(fut) => fut
.poll_unpin(cx)
.map(|x| x.map(|x| MaybeTlsStream::NoTls(x)))
.map_err(|x| x.into()),
MaybeTlsFuture::Tls(fut) => fut.poll_unpin(cx),
}
}
}

// =======

#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct PostgresOffset {
pub txid: i64,
Expand Down Expand Up @@ -139,8 +305,20 @@ impl PostgresExternalTableReader {
config.sslmode
);

let builder = SslConnector::builder(SslMethod::tls())?;
let connector = MakeTlsConnector::new(builder.build());
let connector = match config.sslmode {
SslMode::Disable => MaybeMakeTlsConnector::NoTls(NoTls),
SslMode::Prefer => match SslConnector::builder(SslMethod::tls()) {
Ok(builder) => MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build())),
Err(e) => {
tracing::warn!(error = %e.as_report(), "SSL connector error");
MaybeMakeTlsConnector::NoTls(NoTls)
}
},
SslMode::Require => {
let builder = SslConnector::builder(SslMethod::tls())?;
MaybeMakeTlsConnector::Tls(MakeTlsConnector::new(builder.build()))
}
};

let (client, connection) = tokio_postgres::connect(&database_url, connector).await?;

Expand Down

0 comments on commit 46a9074

Please sign in to comment.