From 6b084bc9b8698730d7c0936f03b2623c9a3065db Mon Sep 17 00:00:00 2001
From: twosom <72733442+twosom@users.noreply.github.com>
Date: Fri, 11 Oct 2024 00:28:27 +0900
Subject: [PATCH] Add support for Read with Metadata in `MqttIO` (#32668)
* add support for read with metadata in MqttIO
* Update CHANGES.md
* update javadoc
* update javadoc
* refactor : change to use SchemaCoder in MqttIO
- remove MqttRecordCoder
- refactor MqttRecord to use AutoValueSchema
- change related test
---
CHANGES.md | 1 +
.../org/apache/beam/sdk/io/mqtt/MqttIO.java | 163 ++++++++++++++----
.../apache/beam/sdk/io/mqtt/MqttRecord.java | 49 ++++++
.../apache/beam/sdk/io/mqtt/MqttIOTest.java | 78 ++++++++-
4 files changed, 257 insertions(+), 34 deletions(-)
create mode 100644 sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttRecord.java
diff --git a/CHANGES.md b/CHANGES.md
index 6a70a49b2ab1..39bad44dc52c 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -68,6 +68,7 @@
## New Features / Improvements
+* Added support for read with metadata in MqttIO (Java) ([#32195](https://github.com/apache/beam/issues/32195))
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Added support for processing events which use a global sequence to "ordered" extension (Java) [#32540](https://github.com/apache/beam/pull/32540)
diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
index e1868e2c8461..efc51362d06a 100644
--- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
+++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
@@ -18,6 +18,7 @@
package org.apache.beam.sdk.io.mqtt;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import com.google.auto.value.AutoValue;
import java.io.IOException;
@@ -36,6 +37,7 @@
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
@@ -80,6 +82,48 @@
*
* }
*
+ * <h3>Reading with Metadata from a MQTT broker</h3>
+ *
+ * The {@code readWithMetadata} method extends the functionality of the basic {@code read} method
+ * by returning a {@link PCollection} of metadata that includes both the topic name and the payload.
+ * The metadata is encapsulated in a container class {@link MqttRecord} that includes the topic name
+ * and payload. This allows you to implement business logic that can differ depending on the topic
+ * from which the message was received.
+ *
+ * <pre>{@code
+ * PCollection<MqttRecord> records = pipeline.apply(
+ *   MqttIO.readWithMetadata()
+ *     .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create(
+ *       "tcp://host:11883",
+ *       "my_topic_pattern")));
+ *
+ * }</pre>
+ *
+ * By using the topic information, you can apply different processing logic depending on the
+ * source topic, enhancing the flexibility of message processing.
+ *
+ * <h3>Example</h3>
+ *
+ * <pre>{@code
+ * pipeline
+ *   .apply(MqttIO.readWithMetadata()
+ *     .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create(
+ *       "tcp://host:1883", "my_topic_pattern")))
+ *   .apply(ParDo.of(new DoFn<MqttRecord, Void>() {
+ *     @ProcessElement
+ *     public void processElement(ProcessContext c) {
+ *       MqttRecord record = c.element();
+ *       String topic = record.getTopic();
+ *       byte[] payload = record.getPayload();
+ *       // Apply business logic based on the topic
+ *       if (topic.equals("important_topic")) {
+ *         // Special processing for important_topic
+ *       }
+ *     }
+ *   }));
+ *
+ * }</pre>
+ *
* Writing to a MQTT broker
*
* MqttIO sink supports writing {@code byte[]} to a topic on a MQTT broker.
@@ -130,9 +174,18 @@ public class MqttIO {
private static final Logger LOG = LoggerFactory.getLogger(MqttIO.class);
private static final int MQTT_3_1_MAX_CLIENT_ID_LENGTH = 23;
- public static Read read() {
- return new AutoValue_MqttIO_Read.Builder()
+ public static Read read() {
+ return new AutoValue_MqttIO_Read.Builder()
.setMaxReadTime(null)
+ .setWithMetadata(false)
+ .setMaxNumRecords(Long.MAX_VALUE)
+ .build();
+ }
+
+ public static Read readWithMetadata() {
+ return new AutoValue_MqttIO_Read.Builder()
+ .setMaxReadTime(null)
+ .setWithMetadata(true)
.setMaxNumRecords(Long.MAX_VALUE)
.build();
}
@@ -267,7 +320,7 @@ private MQTT createClient() throws Exception {
/** A {@link PTransform} to read from a MQTT broker. */
@AutoValue
- public abstract static class Read extends PTransform> {
+ public abstract static class Read extends PTransform> {
abstract @Nullable ConnectionConfiguration connectionConfiguration();
@@ -275,21 +328,29 @@ public abstract static class Read extends PTransform
abstract @Nullable Duration maxReadTime();
- abstract Builder builder();
+ abstract Builder builder();
+
+ abstract boolean withMetadata();
+
+ abstract @Nullable Coder coder();
@AutoValue.Builder
- abstract static class Builder {
- abstract Builder setConnectionConfiguration(ConnectionConfiguration config);
+ abstract static class Builder {
+ abstract Builder setConnectionConfiguration(ConnectionConfiguration config);
+
+ abstract Builder setMaxNumRecords(long maxNumRecords);
- abstract Builder setMaxNumRecords(long maxNumRecords);
+ abstract Builder setMaxReadTime(Duration maxReadTime);
- abstract Builder setMaxReadTime(Duration maxReadTime);
+ abstract Builder setWithMetadata(boolean withMetadata);
- abstract Read build();
+ abstract Builder setCoder(Coder coder);
+
+ abstract Read build();
}
/** Define the MQTT connection configuration used to connect to the MQTT broker. */
- public Read withConnectionConfiguration(ConnectionConfiguration configuration) {
+ public Read withConnectionConfiguration(ConnectionConfiguration configuration) {
checkArgument(configuration != null, "configuration can not be null");
return builder().setConnectionConfiguration(configuration).build();
}
@@ -299,7 +360,7 @@ public Read withConnectionConfiguration(ConnectionConfiguration configuration) {
* records is lower than {@code Long.MAX_VALUE}, the {@link Read} will provide a bounded {@link
* PCollection}.
*/
- public Read withMaxNumRecords(long maxNumRecords) {
+ public Read withMaxNumRecords(long maxNumRecords) {
return builder().setMaxNumRecords(maxNumRecords).build();
}
@@ -307,19 +368,33 @@ public Read withMaxNumRecords(long maxNumRecords) {
* Define the max read time (duration) while the {@link Read} will receive messages. When this
* max read time is not null, the {@link Read} will provide a bounded {@link PCollection}.
*/
- public Read withMaxReadTime(Duration maxReadTime) {
+ public Read withMaxReadTime(Duration maxReadTime) {
return builder().setMaxReadTime(maxReadTime).build();
}
@Override
- public PCollection expand(PBegin input) {
+ @SuppressWarnings("unchecked")
+ public PCollection expand(PBegin input) {
checkArgument(connectionConfiguration() != null, "connectionConfiguration can not be null");
checkArgument(connectionConfiguration().getTopic() != null, "topic can not be null");
- org.apache.beam.sdk.io.Read.Unbounded unbounded =
- org.apache.beam.sdk.io.Read.from(new UnboundedMqttSource(this));
+ Coder coder;
+ if (withMetadata()) {
+ try {
+ coder =
+ (Coder) input.getPipeline().getSchemaRegistry().getSchemaCoder(MqttRecord.class);
+ } catch (NoSuchSchemaException e) {
+ throw new RuntimeException(e.getMessage());
+ }
+ } else {
+ coder = (Coder) ByteArrayCoder.of();
+ }
+
+ org.apache.beam.sdk.io.Read.Unbounded unbounded =
+ org.apache.beam.sdk.io.Read.from(
+ new UnboundedMqttSource<>(this.builder().setCoder(coder).build()));
- PTransform> transform = unbounded;
+ PTransform> transform = unbounded;
if (maxNumRecords() < Long.MAX_VALUE || maxReadTime() != null) {
transform = unbounded.withMaxReadTime(maxReadTime()).withMaxNumRecords(maxNumRecords());
@@ -403,27 +478,39 @@ public int hashCode() {
}
@VisibleForTesting
- static class UnboundedMqttSource extends UnboundedSource {
+ static class UnboundedMqttSource extends UnboundedSource {
- private final Read spec;
+ private final Read spec;
- public UnboundedMqttSource(Read spec) {
+ public UnboundedMqttSource(Read spec) {
this.spec = spec;
}
@Override
- public UnboundedReader createReader(
+ @SuppressWarnings("unchecked")
+ public UnboundedReader createReader(
PipelineOptions options, MqttCheckpointMark checkpointMark) {
- return new UnboundedMqttReader(this, checkpointMark);
+ final UnboundedMqttReader unboundedMqttReader;
+ if (spec.withMetadata()) {
+ unboundedMqttReader =
+ new UnboundedMqttReader<>(
+ this,
+ checkpointMark,
+ message -> (T) MqttRecord.of(message.getTopic(), message.getPayload()));
+ } else {
+ unboundedMqttReader = new UnboundedMqttReader<>(this, checkpointMark);
+ }
+
+ return unboundedMqttReader;
}
@Override
- public List split(int desiredNumSplits, PipelineOptions options) {
+ public List> split(int desiredNumSplits, PipelineOptions options) {
// MQTT is based on a pub/sub pattern
// so, if we create several subscribers on the same topic, they all will receive the same
// message, resulting to duplicate messages in the PCollection.
// So, for MQTT, we limit to number of split ot 1 (unique source).
- return Collections.singletonList(new UnboundedMqttSource(spec));
+ return Collections.singletonList(new UnboundedMqttSource<>(spec));
}
@Override
@@ -437,23 +524,24 @@ public Coder getCheckpointMarkCoder() {
}
@Override
- public Coder getOutputCoder() {
- return ByteArrayCoder.of();
+ public Coder getOutputCoder() {
+ return checkNotNull(this.spec.coder(), "coder can not be null");
}
}
@VisibleForTesting
- static class UnboundedMqttReader extends UnboundedSource.UnboundedReader {
+ static class UnboundedMqttReader extends UnboundedSource.UnboundedReader {
- private final UnboundedMqttSource source;
+ private final UnboundedMqttSource source;
private MQTT client;
private BlockingConnection connection;
- private byte[] current;
+ private T current;
private Instant currentTimestamp;
private MqttCheckpointMark checkpointMark;
+ private SerializableFunction extractFn;
- public UnboundedMqttReader(UnboundedMqttSource source, MqttCheckpointMark checkpointMark) {
+ public UnboundedMqttReader(UnboundedMqttSource source, MqttCheckpointMark checkpointMark) {
this.source = source;
this.current = null;
if (checkpointMark != null) {
@@ -461,12 +549,21 @@ public UnboundedMqttReader(UnboundedMqttSource source, MqttCheckpointMark checkp
} else {
this.checkpointMark = new MqttCheckpointMark();
}
+ this.extractFn = message -> (T) message.getPayload();
+ }
+
+ public UnboundedMqttReader(
+ UnboundedMqttSource source,
+ MqttCheckpointMark checkpointMark,
+ SerializableFunction extractFn) {
+ this(source, checkpointMark);
+ this.extractFn = extractFn;
}
@Override
public boolean start() throws IOException {
LOG.debug("Starting MQTT reader ...");
- Read spec = source.spec;
+ Read spec = source.spec;
try {
client = spec.connectionConfiguration().createClient();
LOG.debug("Reader client ID is {}", client.getClientId());
@@ -488,7 +585,7 @@ public boolean advance() throws IOException {
if (message == null) {
return false;
}
- current = message.getPayload();
+ current = this.extractFn.apply(message);
currentTimestamp = Instant.now();
checkpointMark.add(message, currentTimestamp);
} catch (Exception e) {
@@ -520,7 +617,7 @@ public UnboundedSource.CheckpointMark getCheckpointMark() {
}
@Override
- public byte[] getCurrent() {
+ public T getCurrent() {
if (current == null) {
throw new NoSuchElementException();
}
@@ -536,7 +633,7 @@ public Instant getCurrentTimestamp() {
}
@Override
- public UnboundedMqttSource getCurrentSource() {
+ public UnboundedMqttSource getCurrentSource() {
return source;
}
}
diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttRecord.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttRecord.java
new file mode 100644
index 000000000000..bbf27f5c73e7
--- /dev/null
+++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttRecord.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.mqtt;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.sdk.schemas.AutoValueSchema;
+import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
+
+/** A container class for MQTT message metadata, including the topic name and payload. */
+@DefaultSchema(AutoValueSchema.class)
+@AutoValue
+public abstract class MqttRecord {
+ public abstract String getTopic();
+
+ @SuppressWarnings("mutable")
+ public abstract byte[] getPayload();
+
+ static Builder builder() {
+ return new AutoValue_MqttRecord.Builder();
+ }
+
+ static MqttRecord of(String topic, byte[] payload) {
+ return builder().setTopic(topic).setPayload(payload).build();
+ }
+
+ @AutoValue.Builder
+ abstract static class Builder {
+ abstract Builder setTopic(String topic);
+
+ abstract Builder setPayload(byte[] payload);
+
+ abstract MqttRecord build();
+ }
+}
diff --git a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
index 8dfa7838d66a..64b0728c879a 100644
--- a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
+++ b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
@@ -44,6 +44,7 @@
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
@@ -93,7 +94,7 @@ public void startBroker() throws Exception {
@Ignore("https://github.com/apache/beam/issues/18723 Test timeout failure.")
public void testReadNoClientId() throws Exception {
final String topicName = "READ_TOPIC_NO_CLIENT_ID";
- Read mqttReader =
+ Read mqttReader =
MqttIO.read()
.withConnectionConfiguration(
MqttIO.ConnectionConfiguration.create("tcp://localhost:" + port, topicName))
@@ -214,6 +215,81 @@ public void testRead() throws Exception {
publishConnection.disconnect();
}
+ @Test(timeout = 60 * 1000)
+ public void testReadWithMetadata() throws Exception {
+ final String wildcardTopic = "topic/#";
+ final String topic1 = "topic/1";
+ final String topic2 = "topic/2";
+
+ final PTransform> mqttReaderWithMetadata =
+ MqttIO.readWithMetadata()
+ .withConnectionConfiguration(
+ MqttIO.ConnectionConfiguration.create("tcp://localhost:" + port, wildcardTopic))
+ .withMaxNumRecords(10);
+
+ final PCollection output = pipeline.apply(mqttReaderWithMetadata);
+ PAssert.that(output)
+ .containsInAnyOrder(
+ MqttRecord.of(topic1, "This is test 0".getBytes(StandardCharsets.UTF_8)),
+ MqttRecord.of(topic1, "This is test 1".getBytes(StandardCharsets.UTF_8)),
+ MqttRecord.of(topic1, "This is test 2".getBytes(StandardCharsets.UTF_8)),
+ MqttRecord.of(topic1, "This is test 3".getBytes(StandardCharsets.UTF_8)),
+ MqttRecord.of(topic1, "This is test 4".getBytes(StandardCharsets.UTF_8)),
+ MqttRecord.of(topic2, "This is test 5".getBytes(StandardCharsets.UTF_8)),
+ MqttRecord.of(topic2, "This is test 6".getBytes(StandardCharsets.UTF_8)),
+ MqttRecord.of(topic2, "This is test 7".getBytes(StandardCharsets.UTF_8)),
+ MqttRecord.of(topic2, "This is test 8".getBytes(StandardCharsets.UTF_8)),
+ MqttRecord.of(topic2, "This is test 9".getBytes(StandardCharsets.UTF_8)));
+
+ // Produce messages on the brokerService in another thread.
+ // This keeps the pipeline from blocking while it waits for new messages.
+ MQTT client = new MQTT();
+ client.setHost("tcp://localhost:" + port);
+ final BlockingConnection publishConnection = client.blockingConnection();
+ publishConnection.connect();
+ Thread publisherThread =
+ new Thread(
+ () -> {
+ try {
+ LOG.info(
+ "Waiting pipeline connected to the MQTT broker before sending "
+ + "messages ...");
+ boolean pipelineConnected = false;
+ while (!pipelineConnected) {
+ Thread.sleep(1000);
+ for (Connection connection : brokerService.getBroker().getClients()) {
+ if (!connection.getConnectionId().isEmpty()) {
+ pipelineConnected = true;
+ }
+ }
+ }
+ for (int i = 0; i < 5; i++) {
+ publishConnection.publish(
+ topic1,
+ ("This is test " + i).getBytes(StandardCharsets.UTF_8),
+ QoS.EXACTLY_ONCE,
+ false);
+ }
+ for (int i = 5; i < 10; i++) {
+ publishConnection.publish(
+ topic2,
+ ("This is test " + i).getBytes(StandardCharsets.UTF_8),
+ QoS.EXACTLY_ONCE,
+ false);
+ }
+
+ } catch (Exception e) {
+ // nothing to do
+ }
+ });
+
+ publisherThread.start();
+ pipeline.run();
+
+ publishConnection.disconnect();
+ publisherThread.join();
+ }
+
/** Test for BEAM-3282: this test should not timeout. */
@Test(timeout = 30 * 1000)
public void testReceiveWithTimeoutAndNoData() throws Exception {