Skip to content

Commit

Permalink
feat: Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
bakjos committed Mar 6, 2024
1 parent 8c9a8db commit 6d7cbdb
Show file tree
Hide file tree
Showing 12 changed files with 227 additions and 196 deletions.
1 change: 1 addition & 0 deletions .typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ steam = "stream" # You played with Steam games too much.
# Some weird short variable names
ot = "ot"
bui = "bui"
mosquitto = "mosquitto" # This is an MQTT broker.

[default.extend-identifiers]

Expand Down
28 changes: 28 additions & 0 deletions integration_tests/mqtt/create_sink.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-- Create an MQTT sink named `mqtt_sink` fed from `personnel`.
-- It publishes to topic 'test' on the broker at tcp://mqtt-server as
-- append-only, JSON-encoded messages.
CREATE SINK mqtt_sink
FROM
personnel
WITH
(
connector='mqtt',
url='tcp://mqtt-server',
topic= 'test',
type = 'append-only',
-- Ask the sink to publish with the retain flag set.
-- NOTE(review): presumably so a subscriber that connects after the publish
-- still receives a message — confirm against the sink check script.
retain = 'true',
qos = 'at_least_once',
) FORMAT PLAIN ENCODE JSON (
force_append_only='true',
);

-- Seed `personnel` with eight (id, name) rows for the sink to emit.
INSERT INTO
personnel
VALUES
(1, 'Alice'),
(2, 'Bob'),
(3, 'Tom'),
(4, 'Jerry'),
(5, 'Araminta'),
(6, 'Clover'),
(7, 'Posey'),
(8, 'Waverly');

-- NOTE(review): presumably forces the inserted rows to be committed before
-- the test proceeds — confirm against the FLUSH documentation.
FLUSH;
30 changes: 0 additions & 30 deletions integration_tests/mqtt/create_source.sql
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,3 @@ WITH (
topic= 'test',
qos = 'at_least_once',
) FORMAT PLAIN ENCODE JSON;


-- Create an MQTT sink named `mqtt_sink` fed from `personnel`.
-- It publishes to topic 'test' on the broker at tcp://mqtt-server as
-- append-only, JSON-encoded messages, with the retain flag disabled.
CREATE SINK mqtt_sink
FROM
personnel
WITH
(
connector='mqtt',
url='tcp://mqtt-server',
topic= 'test',
type = 'append-only',
retain = 'false',
qos = 'at_least_once',
) FORMAT PLAIN ENCODE JSON (
force_append_only='true',
);

-- Seed `personnel` with eight (id, name) rows for the sink to emit.
INSERT INTO
personnel
VALUES
(1, 'Alice'),
(2, 'Bob'),
(3, 'Tom'),
(4, 'Jerry'),
(5, 'Araminta'),
(6, 'Clover'),
(7, 'Posey'),
(8, 'Waverly');

-- NOTE(review): presumably forces the inserted rows to be committed before
-- the test proceeds — confirm against the FLUSH documentation.
FLUSH;
6 changes: 5 additions & 1 deletion integration_tests/mqtt/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ services:
file: ../../docker/docker-compose.yml
service: risingwave-standalone
mqtt-server:
image: emqx/emqx:5.2.1
image: eclipse-mosquitto
command:
- sh
- -c
- echo "running command"; printf 'allow_anonymous true\nlistener 1883 0.0.0.0' > /mosquitto/config/mosquitto.conf; echo "starting service..."; cat /mosquitto/config/mosquitto.conf;/docker-entrypoint.sh;/usr/sbin/mosquitto -c /mosquitto/config/mosquitto.conf
ports:
- 1883:1883
etcd-0:
Expand Down
14 changes: 14 additions & 0 deletions integration_tests/mqtt/sink_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Check that at least one message can be read from the MQTT topic 'test'.

Runs `mosquitto_sub` inside the `mqtt-server` compose service, counts the
lines it prints, and exits with status 1 when no line was received.
"""
import sys
import subprocess

# Topic to subscribe to on the broker.
TOPIC = "test"

# `-C 1` makes mosquitto_sub exit after one message and `-W 120` makes it give
# up after 120 seconds, so this call cannot block forever.
result = subprocess.run(
    ["docker", "compose", "exec", "mqtt-server", "mosquitto_sub", "-h", "localhost", "-t", TOPIC, "-p", "1883", "-C", "1", "-W", "120"],
    stdout=subprocess.PIPE,
)

# Count newline characters, exactly what `wc -l` reports, without spawning
# an extra process on the host.
rows = result.stdout.count(b"\n")
print(f"{rows} rows in '{TOPIC}'")
if rows < 1:
    print(f"Data check failed for case '{TOPIC}'")
    sys.exit(1)
165 changes: 7 additions & 158 deletions src/connector/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,16 @@ use risingwave_common::bail;
use serde_derive::Deserialize;
use serde_with::json::JsonString;
use serde_with::{serde_as, DisplayFromStr};
use strum_macros::{Display, EnumString};
use tempfile::NamedTempFile;
use time::OffsetDateTime;
use url::Url;
use with_options::WithOptions;

use crate::aws_utils::load_file_descriptor_from_s3;
use crate::deserialize_duration_from_string;
use crate::error::ConnectorResult;
use crate::sink::SinkError;
use crate::source::nats::source::NatsOffset;
use crate::{deserialize_bool_from_string, deserialize_duration_from_string};
// The file describes the common abstractions for each connector and can be used in both source and
// sink.

Expand Down Expand Up @@ -686,161 +685,9 @@ impl NatsCommon {
}
}

/// Delivery guarantee that can be selected through the MQTT `qos` option.
///
/// Because of `serialize_all = "snake_case"`, the textual forms used by
/// `Display` and `FromStr` are `at_least_once`, `at_most_once` and
/// `exactly_once`.
#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)]
#[strum(serialize_all = "snake_case")]
// Every variant name ends in `Once`, which trips clippy's
// `enum_variant_names` lint; the names are kept as they are.
#[allow(clippy::enum_variant_names)]
pub enum QualityOfService {
/// Mapped to `rumqttc::v5::mqttbytes::QoS::AtLeastOnce`.
AtLeastOnce,
/// Mapped to `rumqttc::v5::mqttbytes::QoS::AtMostOnce`; also the fallback
/// when no `qos` is configured.
AtMostOnce,
/// Mapped to `rumqttc::v5::mqttbytes::QoS::ExactlyOnce`.
ExactlyOnce,
}

/// Connection options for the MQTT connector, deserialized from the `WITH (...)` clause.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct MqttCommon {
    /// The url of the broker to connect to. e.g. tcp://localhost.
    /// Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`,
    /// to denote the protocol for establishing a connection with the broker.
    /// `mqtts://`, `ssl://` will use the native certificates if no ca is specified
    pub url: String,

    /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/#
    pub topic: String,

    /// The quality of service to use when publishing messages. Defaults to at_most_once.
    /// Could be at_most_once, at_least_once or exactly_once
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub qos: Option<QualityOfService>,

    /// Username for the mqtt broker
    #[serde(rename = "username")]
    pub user: Option<String>,

    /// Password for the mqtt broker
    pub password: Option<String>,

    /// Prefix for the mqtt client id.
    /// The client id will be generated as `client_prefix`_`actor_id`_`(executor_id|source_id)`. Defaults to risingwave
    pub client_prefix: Option<String>,

    /// `clean_start = true` removes all the state from queues & instructs the broker
    /// to clean all the client state when client disconnects.
    ///
    /// When set `false`, broker will hold the client state and performs pending
    /// operations on the client when reconnection with same `client_id`
    /// happens. Local queue state is also held to retransmit packets after reconnection.
    #[serde(default, deserialize_with = "deserialize_bool_from_string")]
    pub clean_start: bool,

    /// The maximum number of inflight messages. Defaults to 100
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub inflight_messages: Option<usize>,

    /// Path to CA certificate file for verifying the broker's key.
    // This option needs its own key: it previously reused `tls.client_key`,
    // colliding with the `client_key` field below.
    #[serde(rename = "tls.ca")]
    pub ca: Option<String>,

    /// Path to client's certificate file (PEM). Required for client authentication.
    /// Can be a file path under fs:// or a string with the certificate content.
    #[serde(rename = "tls.client_cert")]
    pub client_cert: Option<String>,

    /// Path to client's private key file (PEM). Required for client authentication.
    /// Can be a file path under fs:// or a string with the private key content.
    #[serde(rename = "tls.client_key")]
    pub client_key: Option<String>,
}

impl MqttCommon {
pub(crate) fn build_client(
&self,
actor_id: u32,
id: u64,
) -> ConnectorResult<(rumqttc::v5::AsyncClient, rumqttc::v5::EventLoop)> {
let client_id = format!(
"{}_{}_{}",
self.client_prefix.as_deref().unwrap_or("risingwave"),
actor_id,
id
);

let mut url = url::Url::parse(&self.url)?;

let ssl = matches!(url.scheme(), "mqtts" | "ssl");

url.query_pairs_mut().append_pair("client_id", &client_id);

tracing::debug!("connecting mqtt using url: {}", url.as_str());

let mut options = rumqttc::v5::MqttOptions::try_from(url)?;
options.set_keep_alive(std::time::Duration::from_secs(10));

options.set_clean_start(self.clean_start);

if ssl {
let tls_config = self.get_tls_config()?;
options.set_transport(rumqttc::Transport::tls_with_config(
rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)),
));
}

if let Some(user) = &self.user {
options.set_credentials(user, self.password.as_deref().unwrap_or_default());
}

Ok(rumqttc::v5::AsyncClient::new(
options,
self.inflight_messages.unwrap_or(100),
))
}

pub(crate) fn qos(&self) -> rumqttc::v5::mqttbytes::QoS {
self.qos
.as_ref()
.map(|qos| match qos {
QualityOfService::AtMostOnce => rumqttc::v5::mqttbytes::QoS::AtMostOnce,
QualityOfService::AtLeastOnce => rumqttc::v5::mqttbytes::QoS::AtLeastOnce,
QualityOfService::ExactlyOnce => rumqttc::v5::mqttbytes::QoS::ExactlyOnce,
})
.unwrap_or(rumqttc::v5::mqttbytes::QoS::AtMostOnce)
}

fn get_tls_config(&self) -> ConnectorResult<tokio_rustls::rustls::ClientConfig> {
let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
if let Some(ca) = &self.ca {
let certificates = load_certs(ca)?;
for cert in certificates {
root_cert_store.add(&cert).unwrap();
}
} else {
for cert in
rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
root_cert_store
.add(&tokio_rustls::rustls::Certificate(cert.0))
.unwrap();
}
}

let builder = tokio_rustls::rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store);

let tls_config = if let (Some(client_cert), Some(client_key)) =
(self.client_cert.as_ref(), self.client_key.as_ref())
{
let certs = load_certs(client_cert)?;
let key = load_private_key(client_key)?;

builder.with_client_auth_cert(certs, key)?
} else {
builder.with_no_client_auth()
};

Ok(tls_config)
}
}

fn load_certs(certificates: &str) -> ConnectorResult<Vec<tokio_rustls::rustls::Certificate>> {
pub(crate) fn load_certs(
certificates: &str,
) -> ConnectorResult<Vec<tokio_rustls::rustls::Certificate>> {
let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") {
std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
} else {
Expand All @@ -855,7 +702,9 @@ fn load_certs(certificates: &str) -> ConnectorResult<Vec<tokio_rustls::rustls::C
.collect())
}

fn load_private_key(certificate: &str) -> ConnectorResult<tokio_rustls::rustls::PrivateKey> {
pub(crate) fn load_private_key(
certificate: &str,
) -> ConnectorResult<tokio_rustls::rustls::PrivateKey> {
let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") {
std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())?
} else {
Expand Down
1 change: 1 addition & 0 deletions src/connector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub mod sink;
pub mod source;

pub mod common;
pub mod mqtt_common;

pub use paste::paste;

Expand Down
Loading

0 comments on commit 6d7cbdb

Please sign in to comment.