diff --git a/CHANGES.md b/CHANGES.md index be4e0ba4d0f6..1d1d3bcdbd3c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -66,6 +66,7 @@ ## New Features / Improvements +* Added support for read with metadata in MqttIO (Java) ([#32195](https://github.com/apache/beam/issues/32195)) * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## Breaking Changes diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java index e1868e2c8461..efc51362d06a 100644 --- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java +++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.mqtt; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import com.google.auto.value.AutoValue; import java.io.IOException; @@ -36,6 +37,7 @@ import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -80,6 +82,48 @@ * * } * + *

Reading with Metadata from a MQTT broker

+ * + *

The {@code readWithMetadata} method extends the basic {@code read} method + * by returning a {@link PCollection} of {@link MqttRecord} elements instead of raw {@code byte[]} + * payloads. Each {@link MqttRecord} is a container that carries both the topic name and the + * payload of a received message. This allows you to implement business logic that differs + * depending on the topic from which the message was received. + * + *

{@code
+ * PCollection<MqttRecord> records = pipeline.apply(
+ *   MqttIO.readWithMetadata()
+ *    .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create(
+ *      "tcp://host:11883",
+ *      "my_topic_pattern")));
+ *
+ * }
+ * + *

By using the topic information, you can apply different processing logic depending on the + * source topic, enhancing the flexibility of message processing. + * + *

Example

+ * + *
{@code
+ * pipeline
+ *   .apply(MqttIO.readWithMetadata()
+ *     .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create(
+ *       "tcp://host:1883", "my_topic_pattern")))
+ *   .apply(ParDo.of(new DoFn<MqttRecord, Void>() {
+ *     @ProcessElement
+ *     public void processElement(ProcessContext c) {
+ *       MqttRecord record = c.element();
+ *       String topic = record.getTopic();
+ *       byte[] payload = record.getPayload();
+ *       // Apply business logic based on the topic
+ *       if (topic.equals("important_topic")) {
+ *         // Special processing for important_topic
+ *       }
+ *     }
+ *   }));
+ *
+ * }
+ * *

Writing to a MQTT broker

* *

MqttIO sink supports writing {@code byte[]} to a topic on a MQTT broker. @@ -130,9 +174,18 @@ public class MqttIO { private static final Logger LOG = LoggerFactory.getLogger(MqttIO.class); private static final int MQTT_3_1_MAX_CLIENT_ID_LENGTH = 23; - public static Read read() { - return new AutoValue_MqttIO_Read.Builder() + public static Read read() { + return new AutoValue_MqttIO_Read.Builder() .setMaxReadTime(null) + .setWithMetadata(false) + .setMaxNumRecords(Long.MAX_VALUE) + .build(); + } + + public static Read readWithMetadata() { + return new AutoValue_MqttIO_Read.Builder() + .setMaxReadTime(null) + .setWithMetadata(true) .setMaxNumRecords(Long.MAX_VALUE) .build(); } @@ -267,7 +320,7 @@ private MQTT createClient() throws Exception { /** A {@link PTransform} to read from a MQTT broker. */ @AutoValue - public abstract static class Read extends PTransform> { + public abstract static class Read extends PTransform> { abstract @Nullable ConnectionConfiguration connectionConfiguration(); @@ -275,21 +328,29 @@ public abstract static class Read extends PTransform abstract @Nullable Duration maxReadTime(); - abstract Builder builder(); + abstract Builder builder(); + + abstract boolean withMetadata(); + + abstract @Nullable Coder coder(); @AutoValue.Builder - abstract static class Builder { - abstract Builder setConnectionConfiguration(ConnectionConfiguration config); + abstract static class Builder { + abstract Builder setConnectionConfiguration(ConnectionConfiguration config); + + abstract Builder setMaxNumRecords(long maxNumRecords); - abstract Builder setMaxNumRecords(long maxNumRecords); + abstract Builder setMaxReadTime(Duration maxReadTime); - abstract Builder setMaxReadTime(Duration maxReadTime); + abstract Builder setWithMetadata(boolean withMetadata); - abstract Read build(); + abstract Builder setCoder(Coder coder); + + abstract Read build(); } /** Define the MQTT connection configuration used to connect to the MQTT broker. 
*/ - public Read withConnectionConfiguration(ConnectionConfiguration configuration) { + public Read withConnectionConfiguration(ConnectionConfiguration configuration) { checkArgument(configuration != null, "configuration can not be null"); return builder().setConnectionConfiguration(configuration).build(); } @@ -299,7 +360,7 @@ public Read withConnectionConfiguration(ConnectionConfiguration configuration) { * records is lower than {@code Long.MAX_VALUE}, the {@link Read} will provide a bounded {@link * PCollection}. */ - public Read withMaxNumRecords(long maxNumRecords) { + public Read withMaxNumRecords(long maxNumRecords) { return builder().setMaxNumRecords(maxNumRecords).build(); } @@ -307,19 +368,33 @@ public Read withMaxNumRecords(long maxNumRecords) { * Define the max read time (duration) while the {@link Read} will receive messages. When this * max read time is not null, the {@link Read} will provide a bounded {@link PCollection}. */ - public Read withMaxReadTime(Duration maxReadTime) { + public Read withMaxReadTime(Duration maxReadTime) { return builder().setMaxReadTime(maxReadTime).build(); } @Override - public PCollection expand(PBegin input) { + @SuppressWarnings("unchecked") + public PCollection expand(PBegin input) { checkArgument(connectionConfiguration() != null, "connectionConfiguration can not be null"); checkArgument(connectionConfiguration().getTopic() != null, "topic can not be null"); - org.apache.beam.sdk.io.Read.Unbounded unbounded = - org.apache.beam.sdk.io.Read.from(new UnboundedMqttSource(this)); + Coder coder; + if (withMetadata()) { + try { + coder = + (Coder) input.getPipeline().getSchemaRegistry().getSchemaCoder(MqttRecord.class); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e.getMessage()); + } + } else { + coder = (Coder) ByteArrayCoder.of(); + } + + org.apache.beam.sdk.io.Read.Unbounded unbounded = + org.apache.beam.sdk.io.Read.from( + new UnboundedMqttSource<>(this.builder().setCoder(coder).build())); - 
PTransform> transform = unbounded; + PTransform> transform = unbounded; if (maxNumRecords() < Long.MAX_VALUE || maxReadTime() != null) { transform = unbounded.withMaxReadTime(maxReadTime()).withMaxNumRecords(maxNumRecords()); @@ -403,27 +478,39 @@ public int hashCode() { } @VisibleForTesting - static class UnboundedMqttSource extends UnboundedSource { + static class UnboundedMqttSource extends UnboundedSource { - private final Read spec; + private final Read spec; - public UnboundedMqttSource(Read spec) { + public UnboundedMqttSource(Read spec) { this.spec = spec; } @Override - public UnboundedReader createReader( + @SuppressWarnings("unchecked") + public UnboundedReader createReader( PipelineOptions options, MqttCheckpointMark checkpointMark) { - return new UnboundedMqttReader(this, checkpointMark); + final UnboundedMqttReader unboundedMqttReader; + if (spec.withMetadata()) { + unboundedMqttReader = + new UnboundedMqttReader<>( + this, + checkpointMark, + message -> (T) MqttRecord.of(message.getTopic(), message.getPayload())); + } else { + unboundedMqttReader = new UnboundedMqttReader<>(this, checkpointMark); + } + + return unboundedMqttReader; } @Override - public List split(int desiredNumSplits, PipelineOptions options) { + public List> split(int desiredNumSplits, PipelineOptions options) { // MQTT is based on a pub/sub pattern // so, if we create several subscribers on the same topic, they all will receive the same // message, resulting to duplicate messages in the PCollection. // So, for MQTT, we limit to number of split ot 1 (unique source). 
- return Collections.singletonList(new UnboundedMqttSource(spec)); + return Collections.singletonList(new UnboundedMqttSource<>(spec)); } @Override @@ -437,23 +524,24 @@ public Coder getCheckpointMarkCoder() { } @Override - public Coder getOutputCoder() { - return ByteArrayCoder.of(); + public Coder getOutputCoder() { + return checkNotNull(this.spec.coder(), "coder can not be null"); } } @VisibleForTesting - static class UnboundedMqttReader extends UnboundedSource.UnboundedReader { + static class UnboundedMqttReader extends UnboundedSource.UnboundedReader { - private final UnboundedMqttSource source; + private final UnboundedMqttSource source; private MQTT client; private BlockingConnection connection; - private byte[] current; + private T current; private Instant currentTimestamp; private MqttCheckpointMark checkpointMark; + private SerializableFunction extractFn; - public UnboundedMqttReader(UnboundedMqttSource source, MqttCheckpointMark checkpointMark) { + public UnboundedMqttReader(UnboundedMqttSource source, MqttCheckpointMark checkpointMark) { this.source = source; this.current = null; if (checkpointMark != null) { @@ -461,12 +549,21 @@ public UnboundedMqttReader(UnboundedMqttSource source, MqttCheckpointMark checkp } else { this.checkpointMark = new MqttCheckpointMark(); } + this.extractFn = message -> (T) message.getPayload(); + } + + public UnboundedMqttReader( + UnboundedMqttSource source, + MqttCheckpointMark checkpointMark, + SerializableFunction extractFn) { + this(source, checkpointMark); + this.extractFn = extractFn; } @Override public boolean start() throws IOException { LOG.debug("Starting MQTT reader ..."); - Read spec = source.spec; + Read spec = source.spec; try { client = spec.connectionConfiguration().createClient(); LOG.debug("Reader client ID is {}", client.getClientId()); @@ -488,7 +585,7 @@ public boolean advance() throws IOException { if (message == null) { return false; } - current = message.getPayload(); + current = 
this.extractFn.apply(message); currentTimestamp = Instant.now(); checkpointMark.add(message, currentTimestamp); } catch (Exception e) { @@ -520,7 +617,7 @@ public UnboundedSource.CheckpointMark getCheckpointMark() { } @Override - public byte[] getCurrent() { + public T getCurrent() { if (current == null) { throw new NoSuchElementException(); } @@ -536,7 +633,7 @@ public Instant getCurrentTimestamp() { } @Override - public UnboundedMqttSource getCurrentSource() { + public UnboundedMqttSource getCurrentSource() { return source; } } diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttRecord.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttRecord.java new file mode 100644 index 000000000000..bbf27f5c73e7 --- /dev/null +++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttRecord.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.mqtt; + +import com.google.auto.value.AutoValue; +import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; + +/** A container class for MQTT message metadata, including the topic name and payload. 
*/ +@DefaultSchema(AutoValueSchema.class) +@AutoValue +public abstract class MqttRecord { + public abstract String getTopic(); + + @SuppressWarnings("mutable") + public abstract byte[] getPayload(); + + static Builder builder() { + return new AutoValue_MqttRecord.Builder(); + } + + static MqttRecord of(String topic, byte[] payload) { + return builder().setTopic(topic).setPayload(payload).build(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setTopic(String topic); + + abstract Builder setPayload(byte[] payload); + + abstract MqttRecord build(); + } +} diff --git a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java index 8dfa7838d66a..64b0728c879a 100644 --- a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java +++ b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java @@ -44,6 +44,7 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -93,7 +94,7 @@ public void startBroker() throws Exception { @Ignore("https://github.com/apache/beam/issues/18723 Test timeout failure.") public void testReadNoClientId() throws Exception { final String topicName = "READ_TOPIC_NO_CLIENT_ID"; - Read mqttReader = + Read mqttReader = MqttIO.read() .withConnectionConfiguration( MqttIO.ConnectionConfiguration.create("tcp://localhost:" + port, topicName)) @@ -214,6 +215,81 @@ public void testRead() throws Exception { publishConnection.disconnect(); } + @Test(timeout = 60 * 1000) + public void testReadWithMetadata() throws Exception { + final String wildcardTopic = "topic/#"; + final String topic1 = "topic/1"; + final String topic2 = "topic/2"; + 
+ final PTransform> mqttReaderWithMetadata = + MqttIO.readWithMetadata() + .withConnectionConfiguration( + MqttIO.ConnectionConfiguration.create("tcp://localhost:" + port, wildcardTopic)) + .withMaxNumRecords(10); + + final PCollection output = pipeline.apply(mqttReaderWithMetadata); + PAssert.that(output) + .containsInAnyOrder( + MqttRecord.of(topic1, "This is test 0".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic1, "This is test 1".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic1, "This is test 2".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic1, "This is test 3".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic1, "This is test 4".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic2, "This is test 5".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic2, "This is test 6".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic2, "This is test 7".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic2, "This is test 8".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic2, "This is test 9".getBytes(StandardCharsets.UTF_8))); + + // produce messages on the brokerService in another thread + // This thread prevents to block the pipeline waiting for new messages + MQTT client = new MQTT(); + client.setHost("tcp://localhost:" + port); + final BlockingConnection publishConnection = client.blockingConnection(); + publishConnection.connect(); + Thread publisherThread = + new Thread( + () -> { + try { + LOG.info( + "Waiting pipeline connected to the MQTT broker before sending " + + "messages ..."); + boolean pipelineConnected = false; + while (!pipelineConnected) { + Thread.sleep(1000); + for (Connection connection : brokerService.getBroker().getClients()) { + if (!connection.getConnectionId().isEmpty()) { + pipelineConnected = true; + } + } + } + for (int i = 0; i < 5; i++) { + publishConnection.publish( + topic1, + ("This is test " + i).getBytes(StandardCharsets.UTF_8), + QoS.EXACTLY_ONCE, + false); + } + for (int i = 5; 
i < 10; i++) { + publishConnection.publish( + topic2, + ("This is test " + i).getBytes(StandardCharsets.UTF_8), + QoS.EXACTLY_ONCE, + false); + } + + } catch (Exception e) { + // nothing to do + } + }); + + publisherThread.start(); + pipeline.run(); + + publishConnection.disconnect(); + publisherThread.join(); + } + /** Test for BEAM-3282: this test should not timeout. */ @Test(timeout = 30 * 1000) public void testReceiveWithTimeoutAndNoData() throws Exception {