Skip to content

Commit

Permalink
refactor : change to use SchemaCoder in MqttIO
Browse files Browse the repository at this point in the history
- remove MqttRecordCoder
- refactor MqttRecord to use AutoValueSchema
- change related test
  • Loading branch information
twosom committed Oct 7, 2024
1 parent f3969be commit 64c51a4
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.sdk.io.mqtt;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;

import com.google.auto.value.AutoValue;
import java.io.IOException;
Expand All @@ -36,6 +37,7 @@
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
Expand Down Expand Up @@ -330,6 +332,8 @@ public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>>

abstract boolean withMetadata();

abstract @Nullable Coder<T> coder();

@AutoValue.Builder
abstract static class Builder<T> {
abstract Builder<T> setConnectionConfiguration(ConnectionConfiguration config);
Expand All @@ -340,6 +344,8 @@ abstract static class Builder<T> {

abstract Builder<T> setWithMetadata(boolean withMetadata);

abstract Builder<T> setCoder(Coder<T> coder);

abstract Read<T> build();
}

Expand Down Expand Up @@ -367,12 +373,26 @@ public Read<T> withMaxReadTime(Duration maxReadTime) {
}

@Override
@SuppressWarnings("unchecked")
public PCollection<T> expand(PBegin input) {
checkArgument(connectionConfiguration() != null, "connectionConfiguration can not be null");
checkArgument(connectionConfiguration().getTopic() != null, "topic can not be null");

Coder<T> coder;
if (withMetadata()) {
try {
coder =
(Coder<T>) input.getPipeline().getSchemaRegistry().getSchemaCoder(MqttRecord.class);
} catch (NoSuchSchemaException e) {
throw new RuntimeException(e.getMessage());
}
} else {
coder = (Coder<T>) ByteArrayCoder.of();
}

org.apache.beam.sdk.io.Read.Unbounded<T> unbounded =
org.apache.beam.sdk.io.Read.from(new UnboundedMqttSource<>(this));
org.apache.beam.sdk.io.Read.from(
new UnboundedMqttSource<>(this.builder().setCoder(coder).build()));

PTransform<PBegin, PCollection<T>> transform = unbounded;

Expand Down Expand Up @@ -476,7 +496,7 @@ public UnboundedReader<T> createReader(
new UnboundedMqttReader<>(
this,
checkpointMark,
message -> (T) new MqttRecord(message.getTopic(), message.getPayload()));
message -> (T) MqttRecord.of(message.getTopic(), message.getPayload()));
} else {
unboundedMqttReader = new UnboundedMqttReader<>(this, checkpointMark);
}
Expand Down Expand Up @@ -504,13 +524,8 @@ public Coder<MqttCheckpointMark> getCheckpointMarkCoder() {
}

@Override
@SuppressWarnings("unchecked")
public Coder<T> getOutputCoder() {
if (spec.withMetadata()) {
return (Coder<T>) MqttRecordCoder.of();
} else {
return (Coder<T>) ByteArrayCoder.of();
}
return checkNotNull(this.spec.coder(), "coder can not be null");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,33 @@
*/
package org.apache.beam.sdk.io.mqtt;

import java.util.Arrays;
import java.util.Objects;
import org.checkerframework.checker.nullness.qual.Nullable;
import com.google.auto.value.AutoValue;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;

/** A container class for MQTT message metadata, including the topic name and payload. */
public class MqttRecord {
private final String topic;
private final byte[] payload;
@DefaultSchema(AutoValueSchema.class)
@AutoValue
public abstract class MqttRecord {
public abstract String getTopic();

public MqttRecord(String topic, byte[] payload) {
this.topic = topic;
this.payload = payload;
}
@SuppressWarnings("mutable")
public abstract byte[] getPayload();

public String getTopic() {
return topic;
static Builder builder() {
return new AutoValue_MqttRecord.Builder();
}

public byte[] getPayload() {
return payload;
static MqttRecord of(String topic, byte[] payload) {
return builder().setTopic(topic).setPayload(payload).build();
}

@Override
public int hashCode() {
return Objects.hash(topic, Arrays.hashCode(payload));
}
@AutoValue.Builder
abstract static class Builder {
abstract Builder setTopic(String topic);

abstract Builder setPayload(byte[] payload);

@Override
public boolean equals(@Nullable Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
MqttRecord that = (MqttRecord) o;
return Objects.equals(topic, that.topic) && Objects.deepEquals(payload, that.payload);
abstract MqttRecord build();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,16 @@ public void testReadWithMetadata() throws Exception {
final PCollection<MqttRecord> output = pipeline.apply(mqttReaderWithMetadata);
PAssert.that(output)
.containsInAnyOrder(
new MqttRecord(topic1, "This is test 0".getBytes(StandardCharsets.UTF_8)),
new MqttRecord(topic1, "This is test 1".getBytes(StandardCharsets.UTF_8)),
new MqttRecord(topic1, "This is test 2".getBytes(StandardCharsets.UTF_8)),
new MqttRecord(topic1, "This is test 3".getBytes(StandardCharsets.UTF_8)),
new MqttRecord(topic1, "This is test 4".getBytes(StandardCharsets.UTF_8)),
new MqttRecord(topic2, "This is test 5".getBytes(StandardCharsets.UTF_8)),
new MqttRecord(topic2, "This is test 6".getBytes(StandardCharsets.UTF_8)),
new MqttRecord(topic2, "This is test 7".getBytes(StandardCharsets.UTF_8)),
new MqttRecord(topic2, "This is test 8".getBytes(StandardCharsets.UTF_8)),
new MqttRecord(topic2, "This is test 9".getBytes(StandardCharsets.UTF_8)));
MqttRecord.of(topic1, "This is test 0".getBytes(StandardCharsets.UTF_8)),
MqttRecord.of(topic1, "This is test 1".getBytes(StandardCharsets.UTF_8)),
MqttRecord.of(topic1, "This is test 2".getBytes(StandardCharsets.UTF_8)),
MqttRecord.of(topic1, "This is test 3".getBytes(StandardCharsets.UTF_8)),
MqttRecord.of(topic1, "This is test 4".getBytes(StandardCharsets.UTF_8)),
MqttRecord.of(topic2, "This is test 5".getBytes(StandardCharsets.UTF_8)),
MqttRecord.of(topic2, "This is test 6".getBytes(StandardCharsets.UTF_8)),
MqttRecord.of(topic2, "This is test 7".getBytes(StandardCharsets.UTF_8)),
MqttRecord.of(topic2, "This is test 8".getBytes(StandardCharsets.UTF_8)),
MqttRecord.of(topic2, "This is test 9".getBytes(StandardCharsets.UTF_8)));

// produce messages on the brokerService in another thread
// This thread prevents to block the pipeline waiting for new messages
Expand Down

0 comments on commit 64c51a4

Please sign in to comment.