Skip to content

Commit

Permalink
feat: Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
bakjos committed Mar 5, 2024
1 parent 35836a1 commit f862c7f
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 135 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions integration_tests/mqtt/create_source.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ CREATE TABLE mqtt_source_table
)
WITH (
connector='mqtt',
host='mqtt-server',
url='tcp://mqtt-server',
topic= 'test',
qos = 'at_least_once',
) FORMAT PLAIN ENCODE JSON;
Expand All @@ -20,7 +20,7 @@ FROM
WITH
(
connector='mqtt',
host='mqtt-server',
url='tcp://mqtt-server',
topic= 'test',
type = 'append-only',
retain = 'false',
Expand Down
2 changes: 1 addition & 1 deletion src/connector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ risingwave_common = { workspace = true }
risingwave_jni_core = { workspace = true }
risingwave_pb = { workspace = true }
risingwave_rpc_client = { workspace = true }
rumqttc = "0.22.0"
rumqttc = { version = "0.22.0", features = ["url"] }
rust_decimal = "1"
rustls-native-certs = "0.6"
rustls-pemfile = "1"
Expand Down
132 changes: 70 additions & 62 deletions src/connector/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,29 +695,23 @@ pub enum QualityOfService {
ExactlyOnce,
}

#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum Protocol {
Tcp,
Ssl,
}

#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct MqttCommon {
/// Protocol used for RisingWave to communicate with the mqtt brokers. Could be `tcp` or `ssl`, defaults to `tcp`
#[serde_as(as = "Option<DisplayFromStr>")]
pub protocol: Option<Protocol>,

/// Hostname of the mqtt broker
pub host: String,

/// Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl
pub port: Option<i32>,
/// The url of the broker to connect to. e.g. tcp://localhost.
/// Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`,
/// `ws://` or `wss://` to denote the protocol for establishing a connection with the broker.
/// `mqtts://`, `ssl://`, `wss://`
pub url: String,

/// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/#
pub topic: String,

/// The quality of service to use when publishing messages. Defaults to at_most_once.
/// Could be at_most_once, at_least_once or exactly_once
#[serde_as(as = "Option<DisplayFromStr>")]
pub qos: Option<QualityOfService>,

/// Username for the mqtt broker
#[serde(rename = "username")]
pub user: Option<String>,
Expand Down Expand Up @@ -759,64 +753,32 @@ pub struct MqttCommon {
impl MqttCommon {
pub(crate) fn build_client(
&self,
actor_id: u32,
id: u32,
) -> ConnectorResult<(rumqttc::v5::AsyncClient, rumqttc::v5::EventLoop)> {
let ssl = self
.protocol
.as_ref()
.map(|p| p == &Protocol::Ssl)
.unwrap_or_default();

let client_id = format!(
"{}_{}{}",
"{}_{}_{}",
self.client_prefix.as_deref().unwrap_or("risingwave"),
id,
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis()
% 100000,
actor_id,
id
);

let port = self.port.unwrap_or(if ssl { 8883 } else { 1883 }) as u16;
let mut url = url::Url::parse(&self.url)?;

let mut options = rumqttc::v5::MqttOptions::new(client_id, &self.host, port);
let ssl = match url.scheme() {
"mqtts" | "ssl" | "wss" => true,
_ => false,
};

url.query_pairs_mut().append_pair("client_id", &client_id);

let mut options = rumqttc::v5::MqttOptions::try_from(url)?;
options.set_keep_alive(std::time::Duration::from_secs(10));

options.set_clean_start(self.clean_start);

if ssl {
let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
if let Some(ca) = &self.ca {
let certificates = load_certs(ca)?;
for cert in certificates {
root_cert_store.add(&cert).unwrap();
}
} else {
for cert in
rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
root_cert_store
.add(&tokio_rustls::rustls::Certificate(cert.0))
.unwrap();
}
}

let builder = tokio_rustls::rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store);

let tls_config = if let (Some(client_cert), Some(client_key)) =
(self.client_cert.as_ref(), self.client_key.as_ref())
{
let certs = load_certs(client_cert)?;
let key = load_private_key(client_key)?;

builder.with_client_auth_cert(certs, key)?
} else {
builder.with_no_client_auth()
};

let tls_config = self.get_tls_config()?;
options.set_transport(rumqttc::Transport::tls_with_config(
rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)),
));
Expand All @@ -831,6 +793,52 @@ impl MqttCommon {
self.inflight_messages.unwrap_or(100),
))
}

pub(crate) fn qos(&self) -> rumqttc::v5::mqttbytes::QoS {
self.qos
.as_ref()
.map(|qos| match qos {
QualityOfService::AtMostOnce => rumqttc::v5::mqttbytes::QoS::AtMostOnce,
QualityOfService::AtLeastOnce => rumqttc::v5::mqttbytes::QoS::AtLeastOnce,
QualityOfService::ExactlyOnce => rumqttc::v5::mqttbytes::QoS::ExactlyOnce,
})
.unwrap_or(rumqttc::v5::mqttbytes::QoS::AtMostOnce)
}

fn get_tls_config(&self) -> ConnectorResult<tokio_rustls::rustls::ClientConfig> {
let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
if let Some(ca) = &self.ca {
let certificates = load_certs(ca)?;
for cert in certificates {
root_cert_store.add(&cert).unwrap();
}
} else {
for cert in
rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
root_cert_store
.add(&tokio_rustls::rustls::Certificate(cert.0))
.unwrap();
}
}

let builder = tokio_rustls::rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store);

let tls_config = if let (Some(client_cert), Some(client_key)) =
(self.client_cert.as_ref(), self.client_key.as_ref())
{
let certs = load_certs(client_cert)?;
let key = load_private_key(client_key)?;

builder.with_client_auth_cert(certs, key)?
} else {
builder.with_no_client_auth()
};

Ok(tls_config)
}
}

fn load_certs(certificates: &str) -> ConnectorResult<Vec<tokio_rustls::rustls::Certificate>> {
Expand Down
1 change: 1 addition & 0 deletions src/connector/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def_anyhow_newtype! {
google_cloud_pubsub::client::google_cloud_auth::error::Error => "Google Cloud error",
tokio_rustls::rustls::Error => "TLS error",
rumqttc::v5::ClientError => "MQTT error",
rumqttc::v5::OptionError => "MQTT error",
}

pub type ConnectorResult<T, E = ConnectorError> = std::result::Result<T, E>;
Expand Down
48 changes: 17 additions & 31 deletions src/connector/src/sink/mqtt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ use risingwave_common::catalog::Schema;
use rumqttc::v5::mqttbytes::QoS;
use rumqttc::v5::ConnectionError;
use serde_derive::Deserialize;
use serde_with::{serde_as, DisplayFromStr};
use serde_with::serde_as;
use thiserror_ext::AsReport;
use with_options::WithOptions;

use super::catalog::SinkFormatDesc;
use super::formatter::SinkFormatterImpl;
use super::writer::FormattedSink;
use super::{DummySinkCommitCoordinator, SinkWriterParam};
use crate::common::{MqttCommon, QualityOfService};
use crate::common::MqttCommon;
use crate::sink::catalog::desc::SinkDesc;
use crate::sink::log_store::DeliveryFutureManagerAddFuture;
use crate::sink::writer::{
Expand All @@ -47,11 +47,6 @@ pub struct MqttConfig {
#[serde(flatten)]
pub common: MqttCommon,

/// The quality of service to use when publishing messages. Defaults to at_most_once.
/// Could be at_most_once, at_least_once or exactly_once
#[serde_as(as = "Option<DisplayFromStr>")]
pub qos: Option<QualityOfService>,

/// Whether the message should be retained by the broker
#[serde(default, deserialize_with = "deserialize_bool_from_string")]
pub retain: bool,
Expand Down Expand Up @@ -132,7 +127,7 @@ impl Sink for MqttSink {
)));
}

let _client = (self.config.common.build_client(0))
let _client = (self.config.common.build_client(0, 0))
.context("validate mqtt sink error")
.map_err(SinkError::Mqtt)?;

Expand Down Expand Up @@ -174,19 +169,11 @@ impl MqttSinkWriter {
)
.await?;

let qos = config
.qos
.as_ref()
.map(|qos| match qos {
QualityOfService::AtMostOnce => QoS::AtMostOnce,
QualityOfService::AtLeastOnce => QoS::AtLeastOnce,
QualityOfService::ExactlyOnce => QoS::ExactlyOnce,
})
.unwrap_or(QoS::AtMostOnce);
let qos = config.common.qos();

let (client, mut eventloop) = config
.common
.build_client(id as u32)
.build_client(0, id as u32)
.map_err(|e| SinkError::Mqtt(anyhow!(e)))?;

let stopped = Arc::new(AtomicBool::new(false));
Expand All @@ -196,24 +183,23 @@ impl MqttSinkWriter {
while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) {
match eventloop.poll().await {
Ok(_) => (),
Err(err) => {
if let ConnectionError::Timeout(_) = err {
Err(err) => match err {
ConnectionError::Timeout(_) => {
continue;
}

if let ConnectionError::MqttState(rumqttc::v5::StateError::Io(err)) = err {
if err.kind() != std::io::ErrorKind::ConnectionAborted {
tracing::error!(
"Failed to poll mqtt eventloop: {}",
err.as_report()
);
std::thread::sleep(std::time::Duration::from_secs(1));
}
} else {
ConnectionError::MqttState(rumqttc::v5::StateError::Io(err))
| ConnectionError::Io(err)
if err.kind() == std::io::ErrorKind::ConnectionAborted
|| err.kind() == std::io::ErrorKind::ConnectionReset =>
{
continue;
}
err => {
println!("Err: {:?}", err);
tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report());
std::thread::sleep(std::time::Duration::from_secs(1));
}
}
},
}
}
});
Expand Down
6 changes: 3 additions & 3 deletions src/connector/src/source/mqtt/enumerator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl SplitEnumerator for MqttSplitEnumerator {
properties: Self::Properties,
context: SourceEnumeratorContextRef,
) -> ConnectorResult<MqttSplitEnumerator> {
let (client, mut eventloop) = properties.common.build_client(context.info.source_id)?;
let (client, mut eventloop) = properties.common.build_client(context.info.source_id, 0)?;

let topic = properties.common.topic.clone();
let mut topics = HashSet::new();
Expand Down Expand Up @@ -92,7 +92,7 @@ impl SplitEnumerator for MqttSplitEnumerator {
continue;
}
tracing::error!(
"[Enumerator] Failed to subscribe to topic {}: {}",
"Failed to subscribe to topic {}: {}",
topic,
err.as_report(),
);
Expand Down Expand Up @@ -127,7 +127,7 @@ impl SplitEnumerator for MqttSplitEnumerator {
bail!("Failed to connect to mqtt broker");
}

tokio::time::sleep(std::time::Duration::from_millis(100)).await;
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
}
}

Expand Down
13 changes: 2 additions & 11 deletions src/connector/src/source/mqtt/source/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use thiserror_ext::AsReport;

use super::message::MqttMessage;
use super::MqttSplit;
use crate::common::QualityOfService;
use crate::error::ConnectorResult as Result;
use crate::parser::ParserConfig;
use crate::source::common::{into_chunk_stream, CommonSplitReader};
Expand Down Expand Up @@ -52,17 +51,9 @@ impl SplitReader for MqttSplitReader {
) -> Result<Self> {
let (client, eventloop) = properties
.common
.build_client(source_ctx.source_info.fragment_id)?;
.build_client(source_ctx.source_info.actor_id, source_ctx.source_info.fragment_id)?;

let qos = properties
.qos
.as_ref()
.map(|qos| match qos {
QualityOfService::AtMostOnce => QoS::AtMostOnce,
QualityOfService::AtLeastOnce => QoS::AtLeastOnce,
QualityOfService::ExactlyOnce => QoS::ExactlyOnce,
})
.unwrap_or(QoS::AtMostOnce);
let qos = properties.common.qos();

client
.subscribe_many(
Expand Down
1 change: 0 additions & 1 deletion src/connector/src/with_options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ impl WithOptions for i64 {}
impl WithOptions for f64 {}
impl WithOptions for std::time::Duration {}
impl WithOptions for crate::common::QualityOfService {}
impl WithOptions for crate::common::Protocol {}
impl WithOptions for crate::sink::kafka::CompressionCodec {}
impl WithOptions for nexmark::config::RateShape {}
impl WithOptions for nexmark::event::EventType {}
Loading

0 comments on commit f862c7f

Please sign in to comment.