Skip to content

Commit

Permalink
Kafka SchemaTransform translation (#31362)
Browse files Browse the repository at this point in the history
* kafka schematransform translation and tests

* switch existing schematransform tests to use Managed API
  • Loading branch information
ahmedabu98 authored Jun 5, 2024
1 parent 9c9de49 commit f2931d3
Show file tree
Hide file tree
Showing 8 changed files with 521 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2665,7 +2665,7 @@ abstract static class Builder<K, V> {
abstract Builder<K, V> setProducerFactoryFn(
@Nullable SerializableFunction<Map<String, Object>, Producer<K, V>> fn);

abstract Builder<K, V> setKeySerializer(Class<? extends Serializer<K>> serializer);
abstract Builder<K, V> setKeySerializer(@Nullable Class<? extends Serializer<K>> serializer);

abstract Builder<K, V> setValueSerializer(Class<? extends Serializer<V>> serializer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ public static Builder builder() {
/** Returns the topic from which to read. */
public abstract String getTopic();

/**
 * Returns the upper bound, in seconds, on how long to read from Kafka. The value is optional and
 * may be {@code null}.
 */
@SchemaFieldDescription("Upper bound of how long to read from Kafka.")
@Nullable
public abstract Integer getMaxReadTimeSeconds();

@SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
@Nullable
public abstract ErrorHandling getErrorHandling();
Expand Down Expand Up @@ -179,6 +183,8 @@ public abstract static class Builder {
/** Sets the topic from which to read. */
public abstract Builder setTopic(String value);

/** Sets the upper bound, in seconds, on how long to read from Kafka. */
public abstract Builder setMaxReadTimeSeconds(Integer maxReadTimeSeconds);

public abstract Builder setErrorHandling(ErrorHandling errorHandling);

/** Builds a {@link KafkaReadSchemaTransformConfiguration} instance. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.io.kafka;

import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;

import com.google.auto.service.AutoService;
import java.io.FileOutputStream;
import java.io.IOException;
Expand All @@ -38,7 +40,9 @@
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.transforms.Convert;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
Expand All @@ -56,7 +60,6 @@
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
Expand All @@ -76,19 +79,6 @@ public class KafkaReadSchemaTransformProvider
public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};

final Boolean isTest;
final Integer testTimeoutSecs;

public KafkaReadSchemaTransformProvider() {
this(false, 0);
}

@VisibleForTesting
KafkaReadSchemaTransformProvider(Boolean isTest, Integer testTimeoutSecs) {
this.isTest = isTest;
this.testTimeoutSecs = testTimeoutSecs;
}

@Override
protected Class<KafkaReadSchemaTransformConfiguration> configurationClass() {
return KafkaReadSchemaTransformConfiguration.class;
Expand All @@ -99,113 +89,7 @@ protected Class<KafkaReadSchemaTransformConfiguration> configurationClass() {
})
@Override
protected SchemaTransform from(KafkaReadSchemaTransformConfiguration configuration) {
configuration.validate();

final String inputSchema = configuration.getSchema();
final int groupId = configuration.hashCode() % Integer.MAX_VALUE;
final String autoOffsetReset =
MoreObjects.firstNonNull(configuration.getAutoOffsetResetConfig(), "latest");

Map<String, Object> consumerConfigs =
new HashMap<>(
MoreObjects.firstNonNull(configuration.getConsumerConfigUpdates(), new HashMap<>()));
consumerConfigs.put(ConsumerConfig.GROUP_ID_CONFIG, "kafka-read-provider-" + groupId);
consumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true);
consumerConfigs.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, 100);
consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset);

String format = configuration.getFormat();
boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());

SerializableFunction<byte[], Row> valueMapper;
Schema beamSchema;

String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl();
if (confluentSchemaRegUrl != null) {
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
final String confluentSchemaRegSubject =
configuration.getConfluentSchemaRegistrySubject();
KafkaIO.Read<byte[], GenericRecord> kafkaRead =
KafkaIO.<byte[], GenericRecord>read()
.withTopic(configuration.getTopic())
.withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
.withBootstrapServers(configuration.getBootstrapServers())
.withConsumerConfigUpdates(consumerConfigs)
.withKeyDeserializer(ByteArrayDeserializer.class)
.withValueDeserializer(
ConfluentSchemaRegistryDeserializerProvider.of(
confluentSchemaRegUrl, confluentSchemaRegSubject));
if (isTest) {
kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs));
}

PCollection<GenericRecord> kafkaValues =
input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());

assert kafkaValues.getCoder().getClass() == AvroCoder.class;
AvroCoder<GenericRecord> coder = (AvroCoder<GenericRecord>) kafkaValues.getCoder();
kafkaValues = kafkaValues.setCoder(AvroUtils.schemaCoder(coder.getSchema()));
return PCollectionRowTuple.of("output", kafkaValues.apply(Convert.toRows()));
}
};
}
if ("RAW".equals(format)) {
beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
valueMapper = getRawBytesToRowFunction(beamSchema);
} else if ("PROTO".equals(format)) {
String fileDescriptorPath = configuration.getFileDescriptorPath();
String messageName = configuration.getMessageName();
if (fileDescriptorPath != null) {
beamSchema = ProtoByteUtils.getBeamSchemaFromProto(fileDescriptorPath, messageName);
valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(fileDescriptorPath, messageName);
} else {
beamSchema = ProtoByteUtils.getBeamSchemaFromProtoSchema(inputSchema, messageName);
valueMapper = ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(inputSchema, messageName);
}
} else if ("JSON".equals(format)) {
beamSchema = JsonUtils.beamSchemaFromJsonSchema(inputSchema);
valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
} else {
beamSchema = AvroUtils.toBeamSchema(new org.apache.avro.Schema.Parser().parse(inputSchema));
valueMapper = AvroUtils.getAvroBytesToRowFunction(beamSchema);
}

return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
KafkaIO.Read<byte[], byte[]> kafkaRead =
KafkaIO.readBytes()
.withConsumerConfigUpdates(consumerConfigs)
.withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
.withTopic(configuration.getTopic())
.withBootstrapServers(configuration.getBootstrapServers());
if (isTest) {
kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs));
}

PCollection<byte[]> kafkaValues =
input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());

Schema errorSchema = ErrorHandling.errorSchemaBytes();
PCollectionTuple outputTuple =
kafkaValues.apply(
ParDo.of(
new ErrorFn(
"Kafka-read-error-counter", valueMapper, errorSchema, handleErrors))
.withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

PCollectionRowTuple outputRows =
PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));

PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
if (handleErrors) {
outputRows = outputRows.and(configuration.getErrorHandling().getOutput(), errorOutput);
}
return outputRows;
}
};
return new KafkaReadSchemaTransform(configuration);
}

public static SerializableFunction<byte[], Row> getRawBytesToRowFunction(Schema rawSchema) {
Expand All @@ -232,6 +116,140 @@ public List<String> outputCollectionNames() {
return Arrays.asList("output", "errors");
}

static class KafkaReadSchemaTransform extends SchemaTransform {
private final KafkaReadSchemaTransformConfiguration configuration;

KafkaReadSchemaTransform(KafkaReadSchemaTransformConfiguration configuration) {
this.configuration = configuration;
}

Row getConfigurationRow() {
try {
// To stay consistent with our SchemaTransform configuration naming conventions,
// we sort lexicographically
return SchemaRegistry.createDefault()
.getToRowFunction(KafkaReadSchemaTransformConfiguration.class)
.apply(configuration)
.sorted()
.toSnakeCase();
} catch (NoSuchSchemaException e) {
throw new RuntimeException(e);
}
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
configuration.validate();

final String inputSchema = configuration.getSchema();
final int groupId = configuration.hashCode() % Integer.MAX_VALUE;
final String autoOffsetReset =
MoreObjects.firstNonNull(configuration.getAutoOffsetResetConfig(), "latest");

Map<String, Object> consumerConfigs =
new HashMap<>(
MoreObjects.firstNonNull(configuration.getConsumerConfigUpdates(), new HashMap<>()));
consumerConfigs.put(ConsumerConfig.GROUP_ID_CONFIG, "kafka-read-provider-" + groupId);
consumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true);
consumerConfigs.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, 100);
consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset);

String format = configuration.getFormat();
boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());

SerializableFunction<byte[], Row> valueMapper;
Schema beamSchema;

String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl();
if (confluentSchemaRegUrl != null) {
final String confluentSchemaRegSubject =
checkArgumentNotNull(configuration.getConfluentSchemaRegistrySubject());
KafkaIO.Read<byte[], GenericRecord> kafkaRead =
KafkaIO.<byte[], GenericRecord>read()
.withTopic(configuration.getTopic())
.withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
.withBootstrapServers(configuration.getBootstrapServers())
.withConsumerConfigUpdates(consumerConfigs)
.withKeyDeserializer(ByteArrayDeserializer.class)
.withValueDeserializer(
ConfluentSchemaRegistryDeserializerProvider.of(
confluentSchemaRegUrl, confluentSchemaRegSubject));
Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
if (maxReadTimeSeconds != null) {
kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
}

PCollection<GenericRecord> kafkaValues =
input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());

assert kafkaValues.getCoder().getClass() == AvroCoder.class;
AvroCoder<GenericRecord> coder = (AvroCoder<GenericRecord>) kafkaValues.getCoder();
kafkaValues = kafkaValues.setCoder(AvroUtils.schemaCoder(coder.getSchema()));
return PCollectionRowTuple.of("output", kafkaValues.apply(Convert.toRows()));
}

if ("RAW".equals(format)) {
beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
valueMapper = getRawBytesToRowFunction(beamSchema);
} else if ("PROTO".equals(format)) {
String fileDescriptorPath = configuration.getFileDescriptorPath();
String messageName = checkArgumentNotNull(configuration.getMessageName());
if (fileDescriptorPath != null) {
beamSchema = ProtoByteUtils.getBeamSchemaFromProto(fileDescriptorPath, messageName);
valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(fileDescriptorPath, messageName);
} else {
beamSchema =
ProtoByteUtils.getBeamSchemaFromProtoSchema(
checkArgumentNotNull(inputSchema), messageName);
valueMapper =
ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(
checkArgumentNotNull(inputSchema), messageName);
}
} else if ("JSON".equals(format)) {
beamSchema = JsonUtils.beamSchemaFromJsonSchema(checkArgumentNotNull(inputSchema));
valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
} else {
beamSchema =
AvroUtils.toBeamSchema(
new org.apache.avro.Schema.Parser().parse(checkArgumentNotNull(inputSchema)));
valueMapper = AvroUtils.getAvroBytesToRowFunction(beamSchema);
}

KafkaIO.Read<byte[], byte[]> kafkaRead =
KafkaIO.readBytes()
.withConsumerConfigUpdates(consumerConfigs)
.withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
.withTopic(configuration.getTopic())
.withBootstrapServers(configuration.getBootstrapServers());
Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
if (maxReadTimeSeconds != null) {
kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
}

PCollection<byte[]> kafkaValues =
input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());

Schema errorSchema = ErrorHandling.errorSchemaBytes();
PCollectionTuple outputTuple =
kafkaValues.apply(
ParDo.of(
new ErrorFn(
"Kafka-read-error-counter", valueMapper, errorSchema, handleErrors))
.withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

PCollectionRowTuple outputRows =
PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));

PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
if (handleErrors) {
outputRows =
outputRows.and(
checkArgumentNotNull(configuration.getErrorHandling()).getOutput(), errorOutput);
}
return outputRows;
}
}

public static class ErrorFn extends DoFn<byte[], Row> {
private final SerializableFunction<byte[], Row> valueMapper;
private final Counter errorCounter;
Expand Down
Loading

0 comments on commit f2931d3

Please sign in to comment.