Kafka SchemaTransform translation #31362 (Merged)
@@ -2665,7 +2665,7 @@ abstract static class Builder<K, V> {
abstract Builder<K, V> setProducerFactoryFn(
SerializableFunction<Map<String, Object>, Producer<K, V>> fn);

abstract Builder<K, V> setKeySerializer(Class<? extends Serializer<K>> serializer);
abstract Builder<K, V> setKeySerializer(@Nullable Class<? extends Serializer<K>> serializer);

abstract Builder<K, V> setValueSerializer(Class<? extends Serializer<V>> serializer);

@@ -149,6 +149,10 @@ public static Builder builder() {
/** Sets the topic from which to read. */
public abstract String getTopic();

@SchemaFieldDescription("Upper bound of how long to read from Kafka.")
@Nullable
public abstract Integer getMaxReadTimeSeconds();

@SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
@Nullable
public abstract ErrorHandling getErrorHandling();
@@ -179,6 +183,8 @@ public abstract static class Builder {
/** Sets the topic from which to read. */
public abstract Builder setTopic(String value);

public abstract Builder setMaxReadTimeSeconds(Integer maxReadTimeSeconds);

public abstract Builder setErrorHandling(ErrorHandling errorHandling);

/** Builds a {@link KafkaReadSchemaTransformConfiguration} instance. */
@@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.io.kafka;

import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;

import com.google.auto.service.AutoService;
import java.io.FileOutputStream;
import java.io.IOException;
@@ -38,7 +40,9 @@
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.transforms.Convert;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
@@ -56,7 +60,6 @@
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
@@ -76,19 +79,6 @@ public class KafkaReadSchemaTransformProvider
public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};

final Boolean isTest;
final Integer testTimeoutSecs;

public KafkaReadSchemaTransformProvider() {
this(false, 0);
}

@VisibleForTesting
KafkaReadSchemaTransformProvider(Boolean isTest, Integer testTimeoutSecs) {
this.isTest = isTest;
this.testTimeoutSecs = testTimeoutSecs;
}

@Override
protected Class<KafkaReadSchemaTransformConfiguration> configurationClass() {
return KafkaReadSchemaTransformConfiguration.class;
@@ -99,113 +89,7 @@ protected Class<KafkaReadSchemaTransformConfiguration> configurationClass() {
})
@Override
protected SchemaTransform from(KafkaReadSchemaTransformConfiguration configuration) {
configuration.validate();

final String inputSchema = configuration.getSchema();
final int groupId = configuration.hashCode() % Integer.MAX_VALUE;
final String autoOffsetReset =
MoreObjects.firstNonNull(configuration.getAutoOffsetResetConfig(), "latest");

Map<String, Object> consumerConfigs =
new HashMap<>(
MoreObjects.firstNonNull(configuration.getConsumerConfigUpdates(), new HashMap<>()));
consumerConfigs.put(ConsumerConfig.GROUP_ID_CONFIG, "kafka-read-provider-" + groupId);
consumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true);
consumerConfigs.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, 100);
consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset);

String format = configuration.getFormat();
boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());

SerializableFunction<byte[], Row> valueMapper;
Schema beamSchema;

String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl();
if (confluentSchemaRegUrl != null) {
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
final String confluentSchemaRegSubject =
configuration.getConfluentSchemaRegistrySubject();
KafkaIO.Read<byte[], GenericRecord> kafkaRead =
KafkaIO.<byte[], GenericRecord>read()
.withTopic(configuration.getTopic())
.withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
.withBootstrapServers(configuration.getBootstrapServers())
.withConsumerConfigUpdates(consumerConfigs)
.withKeyDeserializer(ByteArrayDeserializer.class)
.withValueDeserializer(
ConfluentSchemaRegistryDeserializerProvider.of(
confluentSchemaRegUrl, confluentSchemaRegSubject));
if (isTest) {
kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs));
}

PCollection<GenericRecord> kafkaValues =
input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());

assert kafkaValues.getCoder().getClass() == AvroCoder.class;
AvroCoder<GenericRecord> coder = (AvroCoder<GenericRecord>) kafkaValues.getCoder();
kafkaValues = kafkaValues.setCoder(AvroUtils.schemaCoder(coder.getSchema()));
return PCollectionRowTuple.of("output", kafkaValues.apply(Convert.toRows()));
}
};
}
if ("RAW".equals(format)) {
beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
valueMapper = getRawBytesToRowFunction(beamSchema);
} else if ("PROTO".equals(format)) {
String fileDescriptorPath = configuration.getFileDescriptorPath();
String messageName = configuration.getMessageName();
if (fileDescriptorPath != null) {
beamSchema = ProtoByteUtils.getBeamSchemaFromProto(fileDescriptorPath, messageName);
valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(fileDescriptorPath, messageName);
} else {
beamSchema = ProtoByteUtils.getBeamSchemaFromProtoSchema(inputSchema, messageName);
valueMapper = ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(inputSchema, messageName);
}
} else if ("JSON".equals(format)) {
beamSchema = JsonUtils.beamSchemaFromJsonSchema(inputSchema);
valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
} else {
beamSchema = AvroUtils.toBeamSchema(new org.apache.avro.Schema.Parser().parse(inputSchema));
valueMapper = AvroUtils.getAvroBytesToRowFunction(beamSchema);
}

return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
KafkaIO.Read<byte[], byte[]> kafkaRead =
KafkaIO.readBytes()
.withConsumerConfigUpdates(consumerConfigs)
.withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
.withTopic(configuration.getTopic())
.withBootstrapServers(configuration.getBootstrapServers());
if (isTest) {
kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs));
}

PCollection<byte[]> kafkaValues =
input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());

Schema errorSchema = ErrorHandling.errorSchemaBytes();
PCollectionTuple outputTuple =
kafkaValues.apply(
ParDo.of(
new ErrorFn(
"Kafka-read-error-counter", valueMapper, errorSchema, handleErrors))
.withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

PCollectionRowTuple outputRows =
PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));

PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
if (handleErrors) {
outputRows = outputRows.and(configuration.getErrorHandling().getOutput(), errorOutput);
}
return outputRows;
}
};
return new KafkaReadSchemaTransform(configuration);
}

public static SerializableFunction<byte[], Row> getRawBytesToRowFunction(Schema rawSchema) {
@@ -232,6 +116,140 @@ public List<String> outputCollectionNames() {
return Arrays.asList("output", "errors");
}

static class KafkaReadSchemaTransform extends SchemaTransform {
private final KafkaReadSchemaTransformConfiguration configuration;

KafkaReadSchemaTransform(KafkaReadSchemaTransformConfiguration configuration) {
this.configuration = configuration;
}

Row getConfigurationRow() {
try {
// To stay consistent with our SchemaTransform configuration naming conventions,
// we sort lexicographically
return SchemaRegistry.createDefault()
.getToRowFunction(KafkaReadSchemaTransformConfiguration.class)
.apply(configuration)
.sorted()

Review thread (anchored on the .sorted() call above):

Contributor:
Could you clarify why this sorting is needed? Do we need to do this for every implementation?

Contributor Author:
This is just to keep in line with what TypedSchemaTransformProvider does when producing a config schema:

    return SchemaRegistry.createDefault().getSchema(configurationClass()).sorted().toSnakeCase();

This is due to the SchemaProvider not always producing a consistent schema (#24361), so we sort to keep it consistent.

As for "Do we need to do this for every implementation?": right now, unfortunately, yes. I'm working on adding some things to SchemaTransform (#30943) to avoid having to copy this everywhere. My hope is that this change will make SchemaTransformTranslation sufficient for all IOs and help avoid needing a separate SchemaTransformTranslation for each one.

(An illustrative sketch of this sorting and snake_case conversion appears after the diff, below.)

.toSnakeCase();
} catch (NoSuchSchemaException e) {
throw new RuntimeException(e);
}
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
configuration.validate();

final String inputSchema = configuration.getSchema();
final int groupId = configuration.hashCode() % Integer.MAX_VALUE;
final String autoOffsetReset =
MoreObjects.firstNonNull(configuration.getAutoOffsetResetConfig(), "latest");

Map<String, Object> consumerConfigs =
new HashMap<>(
MoreObjects.firstNonNull(configuration.getConsumerConfigUpdates(), new HashMap<>()));
consumerConfigs.put(ConsumerConfig.GROUP_ID_CONFIG, "kafka-read-provider-" + groupId);
consumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true);
consumerConfigs.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, 100);
consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset);

String format = configuration.getFormat();
boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());

SerializableFunction<byte[], Row> valueMapper;
Schema beamSchema;

String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl();
if (confluentSchemaRegUrl != null) {
final String confluentSchemaRegSubject =
checkArgumentNotNull(configuration.getConfluentSchemaRegistrySubject());
KafkaIO.Read<byte[], GenericRecord> kafkaRead =
KafkaIO.<byte[], GenericRecord>read()
.withTopic(configuration.getTopic())
.withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
.withBootstrapServers(configuration.getBootstrapServers())
.withConsumerConfigUpdates(consumerConfigs)
.withKeyDeserializer(ByteArrayDeserializer.class)
.withValueDeserializer(
ConfluentSchemaRegistryDeserializerProvider.of(
confluentSchemaRegUrl, confluentSchemaRegSubject));
Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
if (maxReadTimeSeconds != null) {
kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
}

PCollection<GenericRecord> kafkaValues =
input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());

assert kafkaValues.getCoder().getClass() == AvroCoder.class;
AvroCoder<GenericRecord> coder = (AvroCoder<GenericRecord>) kafkaValues.getCoder();
kafkaValues = kafkaValues.setCoder(AvroUtils.schemaCoder(coder.getSchema()));
return PCollectionRowTuple.of("output", kafkaValues.apply(Convert.toRows()));
}

if ("RAW".equals(format)) {
beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
valueMapper = getRawBytesToRowFunction(beamSchema);
} else if ("PROTO".equals(format)) {
String fileDescriptorPath = configuration.getFileDescriptorPath();
String messageName = checkArgumentNotNull(configuration.getMessageName());
if (fileDescriptorPath != null) {
beamSchema = ProtoByteUtils.getBeamSchemaFromProto(fileDescriptorPath, messageName);
valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(fileDescriptorPath, messageName);
} else {
beamSchema =
ProtoByteUtils.getBeamSchemaFromProtoSchema(
checkArgumentNotNull(inputSchema), messageName);
valueMapper =
ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(
checkArgumentNotNull(inputSchema), messageName);
}
} else if ("JSON".equals(format)) {
beamSchema = JsonUtils.beamSchemaFromJsonSchema(checkArgumentNotNull(inputSchema));
valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
} else {
beamSchema =
AvroUtils.toBeamSchema(
new org.apache.avro.Schema.Parser().parse(checkArgumentNotNull(inputSchema)));
valueMapper = AvroUtils.getAvroBytesToRowFunction(beamSchema);
}

KafkaIO.Read<byte[], byte[]> kafkaRead =
KafkaIO.readBytes()
.withConsumerConfigUpdates(consumerConfigs)
.withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
.withTopic(configuration.getTopic())
.withBootstrapServers(configuration.getBootstrapServers());
Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
if (maxReadTimeSeconds != null) {
kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
}

PCollection<byte[]> kafkaValues =
input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());

Schema errorSchema = ErrorHandling.errorSchemaBytes();
PCollectionTuple outputTuple =
kafkaValues.apply(
ParDo.of(
new ErrorFn(
"Kafka-read-error-counter", valueMapper, errorSchema, handleErrors))
.withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

PCollectionRowTuple outputRows =
PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));

PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
if (handleErrors) {
outputRows =
outputRows.and(
checkArgumentNotNull(configuration.getErrorHandling()).getOutput(), errorOutput);
}
return outputRows;
}
}

public static class ErrorFn extends DoFn<byte[], Row> {
private final SerializableFunction<byte[], Row> valueMapper;
private final Counter errorCounter;
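To make the sorting discussion in the review thread concrete, here is a minimal standalone sketch, not part of the PR, of why sorting fields lexicographically and converting names to snake_case yields a stable configuration schema. The class name and field names below are hypothetical illustrations; the only Beam SDK calls assumed beyond the standard Schema builder are Schema.sorted() and Schema.toSnakeCase(), both quoted in the review reply above.

import org.apache.beam.sdk.schemas.Schema;

public class SortedConfigSchemaSketch {
  public static void main(String[] args) {
    // Two schemas describing the same (hypothetical) configuration, with fields
    // registered in different orders, as a SchemaProvider may do across runs.
    Schema first =
        Schema.builder()
            .addStringField("bootstrapServers")
            .addStringField("topic")
            .addNullableField("maxReadTimeSeconds", Schema.FieldType.INT32)
            .build();
    Schema second =
        Schema.builder()
            .addNullableField("maxReadTimeSeconds", Schema.FieldType.INT32)
            .addStringField("topic")
            .addStringField("bootstrapServers")
            .build();

    // Schema equality is field-order sensitive, so these two are not equal.
    System.out.println(first.equals(second)); // false

    // Sorting fields lexicographically and snake_casing the names gives one
    // canonical form: bootstrap_servers, max_read_time_seconds, topic.
    Schema canonicalFirst = first.sorted().toSnakeCase();
    Schema canonicalSecond = second.sorted().toSnakeCase();
    System.out.println(canonicalFirst.equals(canonicalSecond)); // true
  }
}

Because KafkaReadSchemaTransform.getConfigurationRow() applies the same sorted()/toSnakeCase() canonicalization to the configuration Row, the Row it produces lines up with the configuration schema that TypedSchemaTransformProvider declares, regardless of the field order the SchemaProvider happened to generate.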