From 3f49c6331eb4b83b0baec7f2a14699564e90de10 Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Fri, 1 Mar 2024 21:34:07 -0500 Subject: [PATCH] feat(stream): Add mqtt connector --- Cargo.lock | 46 ++++++- ci/scripts/gen-integration-test-yaml.py | 1 + .../mqtt-source/create_source.sql | 11 ++ integration_tests/mqtt-source/data_check | 1 + .../mqtt-source/docker-compose.yml | 45 +++++++ integration_tests/mqtt-source/query.sql | 8 ++ src/connector/Cargo.toml | 4 + src/connector/src/common.rs | 126 ++++++++++++++++++ src/connector/src/error.rs | 2 + src/connector/src/macros.rs | 1 + src/connector/src/source/mod.rs | 2 + .../src/source/mqtt/enumerator/mod.rs | 102 ++++++++++++++ src/connector/src/source/mqtt/mod.rs | 55 ++++++++ .../src/source/mqtt/source/message.rs | 48 +++++++ src/connector/src/source/mqtt/source/mod.rs | 20 +++ .../src/source/mqtt/source/reader.rs | 102 ++++++++++++++ src/connector/src/source/mqtt/split.rs | 50 +++++++ src/frontend/src/handler/create_source.rs | 7 +- 18 files changed, 627 insertions(+), 4 deletions(-) create mode 100644 integration_tests/mqtt-source/create_source.sql create mode 100644 integration_tests/mqtt-source/data_check create mode 100644 integration_tests/mqtt-source/docker-compose.yml create mode 100644 integration_tests/mqtt-source/query.sql create mode 100644 src/connector/src/source/mqtt/enumerator/mod.rs create mode 100644 src/connector/src/source/mqtt/mod.rs create mode 100644 src/connector/src/source/mqtt/source/message.rs create mode 100644 src/connector/src/source/mqtt/source/mod.rs create mode 100644 src/connector/src/source/mqtt/source/reader.rs create mode 100644 src/connector/src/source/mqtt/split.rs diff --git a/Cargo.lock b/Cargo.lock index b42535e16b52a..b5b8f7d62c9fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -797,7 +797,7 @@ dependencies = [ "rustls", "rustls-native-certs", "rustls-pemfile", - "rustls-webpki", + "rustls-webpki 0.101.7", "serde", "serde_json", "serde_nanos", @@ -4064,6 +4064,7 @@ checksum = 
"1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" dependencies = [ "futures-core", "futures-sink", + "nanorand", "pin-project", "spin 0.9.8", ] @@ -6360,6 +6361,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "034a0ad7deebf0c2abcf2435950a6666c3c15ea9d8fad0c0f48efa8a7f843fed" +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "native-tls" version = "0.2.11" @@ -9178,7 +9188,10 @@ dependencies = [ "risingwave_jni_core", "risingwave_pb", "risingwave_rpc_client", + "rumqttc", "rust_decimal", + "rustls-native-certs", + "rustls-pemfile", "rw_futures_util", "serde", "serde_derive", @@ -9195,6 +9208,7 @@ dependencies = [ "time", "tokio-postgres", "tokio-retry", + "tokio-rustls", "tokio-stream", "tokio-util", "tracing", @@ -10322,6 +10336,24 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rumqttc" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2433b134712bc17a6f85a35e06b901e6e8d0bb20b5367e1121e6fedc140c0ac" +dependencies = [ + "bytes", + "flume", + "futures", + "log", + "rustls-native-certs", + "rustls-pemfile", + "rustls-webpki 0.100.3", + "thiserror", + "tokio", + "tokio-rustls", +] + [[package]] name = "rust-embed" version = "8.1.0" @@ -10456,7 +10488,7 @@ checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c" dependencies = [ "log", "ring 0.17.5", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] @@ -10481,6 +10513,16 @@ dependencies = [ "base64 0.21.4", ] +[[package]] +name = "rustls-webpki" +version = "0.100.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6a5fc258f1c1276dfe3016516945546e2d5383911efc0fc4f1cdc5df3a4ae3" +dependencies = [ + "ring 0.16.20", + "untrusted 0.7.1", +] + 
[[package]] name = "rustls-webpki" version = "0.101.7" diff --git a/ci/scripts/gen-integration-test-yaml.py b/ci/scripts/gen-integration-test-yaml.py index 8f39ab6edb180..eb439b10e352e 100644 --- a/ci/scripts/gen-integration-test-yaml.py +++ b/ci/scripts/gen-integration-test-yaml.py @@ -34,6 +34,7 @@ 'mindsdb': ['json'], 'vector': ['json'], 'nats': ['json'], + 'mqtt-source': ['json'], 'doris-sink': ['json'], 'starrocks-sink': ['json'], 'deltalake-sink': ['json'], diff --git a/integration_tests/mqtt-source/create_source.sql b/integration_tests/mqtt-source/create_source.sql new file mode 100644 index 0000000000000..6c6344c20cda7 --- /dev/null +++ b/integration_tests/mqtt-source/create_source.sql @@ -0,0 +1,11 @@ + +CREATE TABLE mqtt_source_table +( + id integer, + name varchar, +) +WITH ( + connector='mqtt', + host='mqtt-server', + topic= 'test' +) FORMAT PLAIN ENCODE JSON; diff --git a/integration_tests/mqtt-source/data_check b/integration_tests/mqtt-source/data_check new file mode 100644 index 0000000000000..fcaf3aca97ed0 --- /dev/null +++ b/integration_tests/mqtt-source/data_check @@ -0,0 +1 @@ +mqtt_source_table \ No newline at end of file diff --git a/integration_tests/mqtt-source/docker-compose.yml b/integration_tests/mqtt-source/docker-compose.yml new file mode 100644 index 0000000000000..87969f8ad9044 --- /dev/null +++ b/integration_tests/mqtt-source/docker-compose.yml @@ -0,0 +1,45 @@ +--- +version: "3" +services: + risingwave-standalone: + extends: + file: ../../docker/docker-compose.yml + service: risingwave-standalone + mqtt-server: + image: emqx/emqx:5.2.1 + ports: + - 1883:1883 + etcd-0: + extends: + file: ../../docker/docker-compose.yml + service: etcd-0 + grafana-0: + extends: + file: ../../docker/docker-compose.yml + service: grafana-0 + minio-0: + extends: + file: ../../docker/docker-compose.yml + service: minio-0 + prometheus-0: + extends: + file: ../../docker/docker-compose.yml + service: prometheus-0 + message_queue: + extends: + file: 
../../docker/docker-compose.yml + service: message_queue +volumes: + compute-node-0: + external: false + etcd-0: + external: false + grafana-0: + external: false + minio-0: + external: false + prometheus-0: + external: false + message_queue: + external: false +name: risingwave-compose diff --git a/integration_tests/mqtt-source/query.sql b/integration_tests/mqtt-source/query.sql new file mode 100644 index 0000000000000..5a3abc3b555ce --- /dev/null +++ b/integration_tests/mqtt-source/query.sql @@ -0,0 +1,8 @@ +select + * +from + mqtt_source_table +order by + id +LIMIT + 10; \ No newline at end of file diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index c64f8dec8a1dd..164c2f52a66cf 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -113,7 +113,10 @@ risingwave_common = { workspace = true } risingwave_jni_core = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } +rumqttc = "0.22.0" rust_decimal = "1" +rustls-native-certs = "0.6" +rustls-pemfile = "1" rw_futures_util = { workspace = true } serde = { version = "1", features = ["derive", "rc"] } serde_derive = "1" @@ -137,6 +140,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [ ] } tokio-postgres = { version = "0.7", features = ["with-uuid-1"] } tokio-retry = "0.3" +tokio-rustls = "0.24" tokio-stream = "0.1" tokio-util = { version = "0.7", features = ["codec", "io"] } tonic = { workspace = true } diff --git a/src/connector/src/common.rs b/src/connector/src/common.rs index d5944eb07fa3c..15386fadeb250 100644 --- a/src/connector/src/common.rs +++ b/src/connector/src/common.rs @@ -684,3 +684,129 @@ impl NatsCommon { Ok(creds) } } + +#[serde_as] +#[derive(Deserialize, Debug, Clone, WithOptions)] +pub struct MqttCommon { + /// Protocol used for RisingWave to communicate with MQTT brokers.
Could be tcp or ssl + #[serde(rename = "protocol")] + pub protocol: Option, + #[serde(rename = "host")] + pub host: String, + #[serde(rename = "port")] + pub port: Option, + #[serde(rename = "topic")] + pub topic: String, + #[serde(rename = "username")] + pub user: Option, + #[serde(rename = "password")] + pub password: Option, + #[serde(rename = "client_prefix")] + pub client_prefix: Option, + #[serde(rename = "tls.ca")] + pub ca: Option, + #[serde(rename = "tls.client_cert")] + pub client_cert: Option, + #[serde(rename = "tls.client_key")] + pub client_key: Option, +} + +impl MqttCommon { + pub(crate) fn build_client( + &self, + id: u32, + ) -> ConnectorResult<(rumqttc::v5::AsyncClient, rumqttc::v5::EventLoop)> { + let ssl = self + .protocol + .as_ref() + .map(|p| p == "ssl") + .unwrap_or_default(); + + let client_id = format!( + "{}_{}{}", + self.client_prefix.as_deref().unwrap_or("risingwave"), + id, + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() + % 100000, + ); + + let port = self.port.unwrap_or(if ssl { 8883 } else { 1883 }) as u16; + + let mut options = rumqttc::v5::MqttOptions::new(client_id, &self.host, port); + if ssl { + let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); + if let Some(ca) = &self.ca { + let certificates = load_certs(ca)?; + for cert in certificates { + root_cert_store.add(&cert).unwrap(); + } + } else { + for cert in + rustls_native_certs::load_native_certs().expect("could not load platform certs") + { + root_cert_store + .add(&tokio_rustls::rustls::Certificate(cert.0)) + .unwrap(); + } + } + + let builder = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store); + + let tls_config = if let (Some(client_cert), Some(client_key)) = + (self.client_cert.as_ref(), self.client_key.as_ref()) + { + let certs = load_certs(client_cert)?; + let key = load_private_key(client_key)?; + + 
builder.with_client_auth_cert(certs, key)? + } else { + builder.with_no_client_auth() + }; + + options.set_transport(rumqttc::Transport::tls_with_config( + rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)), + )); + } + + if let Some(user) = &self.user { + options.set_credentials(user, self.password.as_deref().unwrap_or_default()); + } + + Ok(rumqttc::v5::AsyncClient::new(options, 10)) + } +} + +fn load_certs(certificates: &str) -> ConnectorResult> { + let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") { + std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())? + } else { + certificates.as_bytes().to_owned() + }; + + let certs = rustls_pemfile::certs(&mut cert_bytes.as_slice())?; + + Ok(certs + .into_iter() + .map(tokio_rustls::rustls::Certificate) + .collect()) +} + +fn load_private_key(certificate: &str) -> ConnectorResult { + let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") { + std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())? + } else { + certificate.as_bytes().to_owned() + }; + + let certs = rustls_pemfile::pkcs8_private_keys(&mut cert_bytes.as_slice())?; + let cert = certs + .into_iter() + .next() + .ok_or_else(|| anyhow!("No private key found"))?; + Ok(tokio_rustls::rustls::PrivateKey(cert)) +} diff --git a/src/connector/src/error.rs b/src/connector/src/error.rs index 3dc10af3d8e7a..1317981f88919 100644 --- a/src/connector/src/error.rs +++ b/src/connector/src/error.rs @@ -58,6 +58,8 @@ def_anyhow_newtype! 
{ redis::RedisError => "Redis error", arrow_schema::ArrowError => "Arrow error", google_cloud_pubsub::client::google_cloud_auth::error::Error => "Google Cloud error", + tokio_rustls::rustls::Error => "TLS error", + rumqttc::v5::ClientError => "MQTT error", } pub type ConnectorResult = std::result::Result; diff --git a/src/connector/src/macros.rs b/src/connector/src/macros.rs index 4b375254c5ad1..b369e6d8a11e3 100644 --- a/src/connector/src/macros.rs +++ b/src/connector/src/macros.rs @@ -32,6 +32,7 @@ macro_rules! for_all_classified_sources { { Nexmark, $crate::source::nexmark::NexmarkProperties, $crate::source::nexmark::NexmarkSplit }, { Datagen, $crate::source::datagen::DatagenProperties, $crate::source::datagen::DatagenSplit }, { GooglePubsub, $crate::source::google_pubsub::PubsubProperties, $crate::source::google_pubsub::PubsubSplit }, + { Mqtt, $crate::source::mqtt::MqttProperties, $crate::source::mqtt::split::MqttSplit }, { Nats, $crate::source::nats::NatsProperties, $crate::source::nats::split::NatsSplit }, { S3, $crate::source::filesystem::S3Properties, $crate::source::filesystem::FsSplit }, { Gcs, $crate::source::filesystem::opendal_source::GcsProperties , $crate::source::filesystem::OpendalFsSplit<$crate::source::filesystem::opendal_source::OpendalGcs> }, diff --git a/src/connector/src/source/mod.rs b/src/connector/src/source/mod.rs index 3656820ed95b0..f965d373d9306 100644 --- a/src/connector/src/source/mod.rs +++ b/src/connector/src/source/mod.rs @@ -21,6 +21,7 @@ pub mod google_pubsub; pub mod kafka; pub mod kinesis; pub mod monitor; +pub mod mqtt; pub mod nats; pub mod nexmark; pub mod pulsar; @@ -29,6 +30,7 @@ pub(crate) use common::*; pub use google_pubsub::GOOGLE_PUBSUB_CONNECTOR; pub use kafka::KAFKA_CONNECTOR; pub use kinesis::KINESIS_CONNECTOR; +pub use mqtt::MQTT_CONNECTOR; pub use nats::NATS_CONNECTOR; mod common; pub mod iceberg; diff --git a/src/connector/src/source/mqtt/enumerator/mod.rs b/src/connector/src/source/mqtt/enumerator/mod.rs new 
file mode 100644 index 0000000000000..8e012b1c41036 --- /dev/null +++ b/src/connector/src/source/mqtt/enumerator/mod.rs @@ -0,0 +1,102 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashSet; + +use async_trait::async_trait; +use risingwave_common::bail; +use rumqttc::v5::{Event, Incoming}; +use rumqttc::Outgoing; + +use super::source::MqttSplit; +use super::MqttProperties; +use crate::error::ConnectorResult; +use crate::source::{SourceEnumeratorContextRef, SplitEnumerator}; + +pub struct MqttSplitEnumerator { + topic: String, + client: rumqttc::v5::AsyncClient, + eventloop: rumqttc::v5::EventLoop, +} + +#[async_trait] +impl SplitEnumerator for MqttSplitEnumerator { + type Properties = MqttProperties; + type Split = MqttSplit; + + async fn new( + properties: Self::Properties, + context: SourceEnumeratorContextRef, + ) -> ConnectorResult { + let (client, eventloop) = properties.common.build_client(context.info.source_id)?; + + Ok(Self { + topic: properties.common.topic, + client, + eventloop, + }) + } + + async fn list_splits(&mut self) -> ConnectorResult> { + if !self.topic.contains('#') && !self.topic.contains('+') { + self.client + .subscribe(self.topic.clone(), rumqttc::v5::mqttbytes::QoS::AtLeastOnce) + .await?; + + let start = std::time::Instant::now(); + loop { + match self.eventloop.poll().await { + Ok(Event::Outgoing(Outgoing::Subscribe(_))) => { + break; + } + _ => { + if 
start.elapsed().as_secs() > 5 { + bail!("Failed to subscribe to topic {}", self.topic); + } + } + } + } + self.client.unsubscribe(self.topic.clone()).await?; + + return Ok(vec![MqttSplit::new(self.topic.clone())]); + } + + self.client + .subscribe(self.topic.clone(), rumqttc::v5::mqttbytes::QoS::AtLeastOnce) + .await?; + + let start = std::time::Instant::now(); + let mut topics = HashSet::new(); + loop { + match self.eventloop.poll().await { + Ok(Event::Incoming(Incoming::Publish(p))) => { + topics.insert(String::from_utf8_lossy(&p.topic).to_string()); + } + _ => { + if start.elapsed().as_secs() > 15 { + self.client.unsubscribe(self.topic.clone()).await?; + if topics.is_empty() { + tracing::warn!( + "Failed to find any topics for pattern {}, using a single split", + self.topic + ); + return Ok(vec![MqttSplit::new(self.topic.clone())]); + } + return Ok(topics.into_iter().map(MqttSplit::new).collect()); + } + } + } + } + } +} diff --git a/src/connector/src/source/mqtt/mod.rs b/src/connector/src/source/mqtt/mod.rs new file mode 100644 index 0000000000000..c1085849255b4 --- /dev/null +++ b/src/connector/src/source/mqtt/mod.rs @@ -0,0 +1,55 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +pub mod enumerator; +pub mod source; +pub mod split; + +use std::collections::HashMap; + +use serde::Deserialize; +use with_options::WithOptions; + +use crate::common::MqttCommon; +use crate::source::mqtt::enumerator::MqttSplitEnumerator; +use crate::source::mqtt::source::{MqttSplit, MqttSplitReader}; +use crate::source::SourceProperties; + +pub const MQTT_CONNECTOR: &str = "mqtt"; + +#[derive(Clone, Debug, Deserialize, WithOptions)] +pub struct MqttProperties { + #[serde(flatten)] + pub common: MqttCommon, + + // 0 - AtMostOnce, 1 - AtLeastOnce (default), 2 - ExactlyOnce + pub qos: Option, + + #[serde(flatten)] + pub unknown_fields: HashMap, +} + +impl SourceProperties for MqttProperties { + type Split = MqttSplit; + type SplitEnumerator = MqttSplitEnumerator; + type SplitReader = MqttSplitReader; + + const SOURCE_NAME: &'static str = MQTT_CONNECTOR; +} + +impl crate::source::UnknownFields for MqttProperties { + fn unknown_fields(&self) -> HashMap { + self.unknown_fields.clone() + } +} diff --git a/src/connector/src/source/mqtt/source/message.rs b/src/connector/src/source/mqtt/source/message.rs new file mode 100644 index 0000000000000..16914a3dabcdb --- /dev/null +++ b/src/connector/src/source/mqtt/source/message.rs @@ -0,0 +1,48 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +use rumqttc::v5::mqttbytes::v5::Publish; + +use crate::source::base::SourceMessage; +use crate::source::SourceMeta; + +#[derive(Clone, Debug)] +pub struct MqttMessage { + pub topic: String, + pub sequence_number: String, + pub payload: Vec, +} + +impl From for SourceMessage { + fn from(message: MqttMessage) -> Self { + SourceMessage { + key: None, + payload: Some(message.payload), + // For MQTT, use the packet identifier (pkid) as offset + offset: message.sequence_number, + split_id: message.topic.into(), + meta: SourceMeta::Empty, + } + } +} + +impl MqttMessage { + pub fn new(message: Publish) -> Self { + MqttMessage { + topic: String::from_utf8_lossy(&message.topic).to_string(), + sequence_number: message.pkid.to_string(), + payload: message.payload.to_vec(), + } + } +} diff --git a/src/connector/src/source/mqtt/source/mod.rs b/src/connector/src/source/mqtt/source/mod.rs new file mode 100644 index 0000000000000..2cf5350a0247c --- /dev/null +++ b/src/connector/src/source/mqtt/source/mod.rs @@ -0,0 +1,20 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +mod message; +mod reader; + +pub use reader::*; + +pub use crate::source::mqtt::split::*; diff --git a/src/connector/src/source/mqtt/source/reader.rs b/src/connector/src/source/mqtt/source/reader.rs new file mode 100644 index 0000000000000..e8e6ea3228b89 --- /dev/null +++ b/src/connector/src/source/mqtt/source/reader.rs @@ -0,0 +1,102 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use futures_async_stream::try_stream; +use risingwave_common::bail; +use rumqttc::v5::mqttbytes::v5::Filter; +use rumqttc::v5::mqttbytes::QoS; +use rumqttc::v5::{Event, Incoming}; + +use super::message::MqttMessage; +use super::MqttSplit; +use crate::error::ConnectorResult as Result; +use crate::parser::ParserConfig; +use crate::source::common::{into_chunk_stream, CommonSplitReader}; +use crate::source::mqtt::MqttProperties; +use crate::source::{BoxChunkSourceStream, Column, SourceContextRef, SourceMessage, SplitReader}; + +pub struct MqttSplitReader { + eventloop: rumqttc::v5::EventLoop, + properties: MqttProperties, + parser_config: ParserConfig, + source_ctx: SourceContextRef, +} + +#[async_trait] +impl SplitReader for MqttSplitReader { + type Properties = MqttProperties; + type Split = MqttSplit; + + async fn new( + properties: MqttProperties, + splits: Vec, + parser_config: ParserConfig, + source_ctx: SourceContextRef, + _columns: Option>, + ) -> Result { + let (client, eventloop) = properties + 
.common + .build_client(source_ctx.source_info.fragment_id)?; + + let qos = if let Some(qos) = properties.qos { + match qos { + 0 => QoS::AtMostOnce, + 1 => QoS::AtLeastOnce, + 2 => QoS::ExactlyOnce, + _ => bail!("Invalid QoS level: {}", qos), + } + } else { + QoS::AtLeastOnce + }; + + client + .subscribe_many( + splits + .into_iter() + .map(|split| Filter::new(split.topic, qos)), + ) + .await?; + + Ok(Self { + eventloop, + properties, + parser_config, + source_ctx, + }) + } + + fn into_stream(self) -> BoxChunkSourceStream { + let parser_config = self.parser_config.clone(); + let source_context = self.source_ctx.clone(); + into_chunk_stream(self, parser_config, source_context) + } +} + +impl CommonSplitReader for MqttSplitReader { + #[try_stream(ok = Vec, error = crate::error::ConnectorError)] + async fn into_data_stream(self) { + let mut eventloop = self.eventloop; + loop { + match eventloop.poll().await { + Ok(Event::Incoming(Incoming::Publish(p))) => { + let msg = MqttMessage::new(p); + yield vec![SourceMessage::from(msg)]; + } + Ok(_) => (), + Err(err) => bail!("Error polling mqtt event loop: {}", err), + } + } + } +} diff --git a/src/connector/src/source/mqtt/split.rs b/src/connector/src/source/mqtt/split.rs new file mode 100644 index 0000000000000..b86bde6097ae9 --- /dev/null +++ b/src/connector/src/source/mqtt/split.rs @@ -0,0 +1,50 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use risingwave_common::types::JsonbVal; +use serde::{Deserialize, Serialize}; + +use crate::error::ConnectorResult; +use crate::source::{SplitId, SplitMetaData}; + +/// The states of an MQTT split, which will be persisted to checkpoint. +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)] +pub struct MqttSplit { + pub(crate) topic: String, +} + +impl SplitMetaData for MqttSplit { + fn id(&self) -> SplitId { + // TODO: should avoid constructing a string every time + self.topic.clone().into() + } + + fn restore_from_json(value: JsonbVal) -> ConnectorResult { + serde_json::from_value(value.take()).map_err(Into::into) + } + + fn encode_to_json(&self) -> JsonbVal { + serde_json::to_value(self.clone()).unwrap().into() + } + + fn update_with_offset(&mut self, _start_sequence: String) -> ConnectorResult<()> { + Ok(()) + } +} + +impl MqttSplit { + pub fn new(topic: String) -> Self { + Self { topic } + } +} diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs index 364ae3cb80d91..fd5db2d2be3cb 100644 --- a/src/frontend/src/handler/create_source.rs +++ b/src/frontend/src/handler/create_source.rs @@ -48,8 +48,8 @@ use risingwave_connector::source::nexmark::source::{get_event_data_types_with_na use risingwave_connector::source::test_source::TEST_CONNECTOR; use risingwave_connector::source::{ ConnectorProperties, GCS_CONNECTOR, GOOGLE_PUBSUB_CONNECTOR, KAFKA_CONNECTOR, - KINESIS_CONNECTOR, NATS_CONNECTOR, NEXMARK_CONNECTOR, OPENDAL_S3_CONNECTOR, POSIX_FS_CONNECTOR, - PULSAR_CONNECTOR, S3_CONNECTOR, + KINESIS_CONNECTOR, MQTT_CONNECTOR, NATS_CONNECTOR, NEXMARK_CONNECTOR, OPENDAL_S3_CONNECTOR, + POSIX_FS_CONNECTOR, PULSAR_CONNECTOR, S3_CONNECTOR, }; use risingwave_pb::catalog::{ PbSchemaRegistryNameStrategy, PbSource, StreamSourceInfo, WatermarkDesc, @@ -994,6 +994,9 @@ static CONNECTORS_COMPATIBLE_FORMATS: LazyLock hashmap!( Format::Plain => vec![Encode::Json, Encode::Protobuf], ), + MQTT_CONNECTOR => hashmap!( +
Format::Plain => vec![Encode::Json, Encode::Bytes], + ), TEST_CONNECTOR => hashmap!( Format::Plain => vec![Encode::Json], ),