diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 8f995a63a10f..35aabbbfd97b 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -2665,7 +2665,7 @@ abstract static class Builder<K, V> {
     abstract Builder<K, V> setProducerFactoryFn(
         @Nullable SerializableFunction<Map<String, Object>, Producer<K, V>> fn);
 
-    abstract Builder<K, V> setKeySerializer(Class<? extends Serializer<K>> serializer);
+    abstract Builder<K, V> setKeySerializer(@Nullable Class<? extends Serializer<K>> serializer);
 
     abstract Builder<K, V> setValueSerializer(Class<? extends Serializer<V>> serializer);
 
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java
index 13f5249a6c3b..693c1371f78c 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java
@@ -149,6 +149,10 @@ public static Builder builder() {
   /** Sets the topic from which to read. */
   public abstract String getTopic();
 
+  @SchemaFieldDescription("Upper bound of how long to read from Kafka.")
+  @Nullable
+  public abstract Integer getMaxReadTimeSeconds();
+
   @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
   @Nullable
   public abstract ErrorHandling getErrorHandling();
@@ -179,6 +183,8 @@ public abstract static class Builder {
     /** Sets the topic from which to read. */
     public abstract Builder setTopic(String value);
 
+    public abstract Builder setMaxReadTimeSeconds(Integer maxReadTimeSeconds);
+
     public abstract Builder setErrorHandling(ErrorHandling errorHandling);
 
     /** Builds a {@link KafkaReadSchemaTransformConfiguration} instance. */
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
index 13240ea9dc40..b2eeb1a54d1d 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
@@ -17,6 +17,8 @@
  */
 package org.apache.beam.sdk.io.kafka;
 
+import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
+
 import com.google.auto.service.AutoService;
 import java.io.FileOutputStream;
 import java.io.IOException;
@@ -38,7 +40,9 @@
 import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.metrics.Counter;
 import org.apache.beam.sdk.metrics.Metrics;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.schemas.transforms.Convert;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
@@ -56,7 +60,6 @@
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
@@ -76,19 +79,6 @@ public class KafkaReadSchemaTransformProvider
   public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
   public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
 
-  final Boolean isTest;
-  final Integer testTimeoutSecs;
-
-  public KafkaReadSchemaTransformProvider() {
-    this(false, 0);
-  }
-
-  @VisibleForTesting
-  KafkaReadSchemaTransformProvider(Boolean isTest, Integer testTimeoutSecs) {
-    this.isTest = isTest;
-    this.testTimeoutSecs = testTimeoutSecs;
-  }
-
   @Override
   protected Class<KafkaReadSchemaTransformConfiguration> configurationClass() {
     return KafkaReadSchemaTransformConfiguration.class;
@@ -99,113 +89,7 @@ protected Class<KafkaReadSchemaTransformConfiguration> configurationClass() {
   })
   @Override
   protected SchemaTransform from(KafkaReadSchemaTransformConfiguration configuration) {
-    configuration.validate();
-
-    final String inputSchema = configuration.getSchema();
-    final int groupId = configuration.hashCode() % Integer.MAX_VALUE;
-    final String autoOffsetReset =
-        MoreObjects.firstNonNull(configuration.getAutoOffsetResetConfig(), "latest");
-
-    Map<String, Object> consumerConfigs =
-        new HashMap<>(
-            MoreObjects.firstNonNull(configuration.getConsumerConfigUpdates(), new HashMap<>()));
-    consumerConfigs.put(ConsumerConfig.GROUP_ID_CONFIG, "kafka-read-provider-" + groupId);
-    consumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true);
-    consumerConfigs.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, 100);
-    consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset);
-
-    String format = configuration.getFormat();
-    boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
-
-    SerializableFunction<byte[], Row> valueMapper;
-    Schema beamSchema;
-
-    String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl();
-    if (confluentSchemaRegUrl != null) {
-      return new SchemaTransform() {
-        @Override
-        public PCollectionRowTuple expand(PCollectionRowTuple input) {
-          final String confluentSchemaRegSubject =
-              configuration.getConfluentSchemaRegistrySubject();
-          KafkaIO.Read<byte[], GenericRecord> kafkaRead =
-              KafkaIO.<byte[], GenericRecord>read()
-                  .withTopic(configuration.getTopic())
-                  .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
-                  .withBootstrapServers(configuration.getBootstrapServers())
-                  .withConsumerConfigUpdates(consumerConfigs)
-                  .withKeyDeserializer(ByteArrayDeserializer.class)
-                  .withValueDeserializer(
-                      ConfluentSchemaRegistryDeserializerProvider.of(
-                          confluentSchemaRegUrl, confluentSchemaRegSubject));
-          if (isTest) {
-            kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs));
-          }
-
-          PCollection<GenericRecord> kafkaValues =
-              input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
-
-          assert kafkaValues.getCoder().getClass() == AvroCoder.class;
-          AvroCoder<GenericRecord> coder = (AvroCoder<GenericRecord>) kafkaValues.getCoder();
-          kafkaValues = kafkaValues.setCoder(AvroUtils.schemaCoder(coder.getSchema()));
-          return PCollectionRowTuple.of("output", kafkaValues.apply(Convert.toRows()));
-        }
-      };
-    }
-    if ("RAW".equals(format)) {
-      beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
-      valueMapper = getRawBytesToRowFunction(beamSchema);
-    } else if ("PROTO".equals(format)) {
-      String fileDescriptorPath = configuration.getFileDescriptorPath();
-      String messageName = configuration.getMessageName();
-      if (fileDescriptorPath != null) {
-        beamSchema = ProtoByteUtils.getBeamSchemaFromProto(fileDescriptorPath, messageName);
-        valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(fileDescriptorPath, messageName);
-      } else {
-        beamSchema = ProtoByteUtils.getBeamSchemaFromProtoSchema(inputSchema, messageName);
-        valueMapper = ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(inputSchema, messageName);
-      }
-    } else if ("JSON".equals(format)) {
-      beamSchema = JsonUtils.beamSchemaFromJsonSchema(inputSchema);
-      valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
-    } else {
-      beamSchema = AvroUtils.toBeamSchema(new org.apache.avro.Schema.Parser().parse(inputSchema));
-      valueMapper = AvroUtils.getAvroBytesToRowFunction(beamSchema);
-    }
-
-    return new SchemaTransform() {
-      @Override
-      public PCollectionRowTuple expand(PCollectionRowTuple input) {
-        KafkaIO.Read<byte[], byte[]> kafkaRead =
-            KafkaIO.readBytes()
-                .withConsumerConfigUpdates(consumerConfigs)
-                .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
-                .withTopic(configuration.getTopic())
-                .withBootstrapServers(configuration.getBootstrapServers());
-        if (isTest) {
-          kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs));
-        }
-
-        PCollection<byte[]> kafkaValues =
-            input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
-
-        Schema errorSchema = ErrorHandling.errorSchemaBytes();
-        PCollectionTuple outputTuple =
-            kafkaValues.apply(
-                ParDo.of(
-                        new ErrorFn(
-                            "Kafka-read-error-counter", valueMapper, errorSchema, handleErrors))
-                    .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
-
-        PCollectionRowTuple outputRows =
-            PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));
-
-        PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
-        if (handleErrors) {
-          outputRows = outputRows.and(configuration.getErrorHandling().getOutput(), errorOutput);
-        }
-        return outputRows;
-      }
-    };
+    return new KafkaReadSchemaTransform(configuration);
   }
 
   public static SerializableFunction<byte[], Row> getRawBytesToRowFunction(Schema rawSchema) {
@@ -232,6 +116,140 @@ public List<String> outputCollectionNames() {
     return Arrays.asList("output", "errors");
   }
 
+  static class KafkaReadSchemaTransform extends SchemaTransform {
+    private final KafkaReadSchemaTransformConfiguration configuration;
+
+    KafkaReadSchemaTransform(KafkaReadSchemaTransformConfiguration configuration) {
+      this.configuration = configuration;
+    }
+
+    Row getConfigurationRow() {
+      try {
+        // To stay consistent with our SchemaTransform configuration naming conventions,
+        // we sort lexicographically
+        return SchemaRegistry.createDefault()
+            .getToRowFunction(KafkaReadSchemaTransformConfiguration.class)
+            .apply(configuration)
+            .sorted()
+            .toSnakeCase();
+      } catch (NoSuchSchemaException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public PCollectionRowTuple expand(PCollectionRowTuple input) {
+      configuration.validate();
+
+      final String inputSchema = configuration.getSchema();
+      final int groupId = configuration.hashCode() % Integer.MAX_VALUE;
+      final String autoOffsetReset =
+          MoreObjects.firstNonNull(configuration.getAutoOffsetResetConfig(), "latest");
+
+      Map<String, Object> consumerConfigs =
+          new HashMap<>(
+              MoreObjects.firstNonNull(configuration.getConsumerConfigUpdates(), new HashMap<>()));
+      consumerConfigs.put(ConsumerConfig.GROUP_ID_CONFIG, "kafka-read-provider-" + groupId);
+      consumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true);
+      consumerConfigs.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, 100);
+      consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset);
+
+      String format = configuration.getFormat();
+      boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
+
+      SerializableFunction<byte[], Row> valueMapper;
+      Schema beamSchema;
+
+      String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl();
+      if (confluentSchemaRegUrl != null) {
+        final String confluentSchemaRegSubject =
+            checkArgumentNotNull(configuration.getConfluentSchemaRegistrySubject());
+        KafkaIO.Read<byte[], GenericRecord> kafkaRead =
+            KafkaIO.<byte[], GenericRecord>read()
+                .withTopic(configuration.getTopic())
+                .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
+                .withBootstrapServers(configuration.getBootstrapServers())
+                .withConsumerConfigUpdates(consumerConfigs)
+                .withKeyDeserializer(ByteArrayDeserializer.class)
+                .withValueDeserializer(
+                    ConfluentSchemaRegistryDeserializerProvider.of(
+                        confluentSchemaRegUrl, confluentSchemaRegSubject));
+        Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
+        if (maxReadTimeSeconds != null) {
+          kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
+        }
+
+        PCollection<GenericRecord> kafkaValues =
+            input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
+
+        assert kafkaValues.getCoder().getClass() == AvroCoder.class;
+        AvroCoder<GenericRecord> coder = (AvroCoder<GenericRecord>) kafkaValues.getCoder();
+        kafkaValues = kafkaValues.setCoder(AvroUtils.schemaCoder(coder.getSchema()));
+        return PCollectionRowTuple.of("output", kafkaValues.apply(Convert.toRows()));
+      }
+
+      if ("RAW".equals(format)) {
+        beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
+        valueMapper = getRawBytesToRowFunction(beamSchema);
+      } else if ("PROTO".equals(format)) {
+        String fileDescriptorPath = configuration.getFileDescriptorPath();
+        String messageName = checkArgumentNotNull(configuration.getMessageName());
+        if (fileDescriptorPath != null) {
+          beamSchema = ProtoByteUtils.getBeamSchemaFromProto(fileDescriptorPath, messageName);
+          valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(fileDescriptorPath, messageName);
+        } else {
+          beamSchema =
+              ProtoByteUtils.getBeamSchemaFromProtoSchema(
+                  checkArgumentNotNull(inputSchema), messageName);
+          valueMapper =
+              ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(
+                  checkArgumentNotNull(inputSchema), messageName);
+        }
+      } else if ("JSON".equals(format)) {
+        beamSchema = JsonUtils.beamSchemaFromJsonSchema(checkArgumentNotNull(inputSchema));
+        valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
+      } else {
+        beamSchema =
+            AvroUtils.toBeamSchema(
+                new org.apache.avro.Schema.Parser().parse(checkArgumentNotNull(inputSchema)));
+        valueMapper = AvroUtils.getAvroBytesToRowFunction(beamSchema);
+      }
+
+      KafkaIO.Read<byte[], byte[]> kafkaRead =
+          KafkaIO.readBytes()
+              .withConsumerConfigUpdates(consumerConfigs)
+              .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
+              .withTopic(configuration.getTopic())
+              .withBootstrapServers(configuration.getBootstrapServers());
+      Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
+      if (maxReadTimeSeconds != null) {
+        kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
+      }
+
+      PCollection<byte[]> kafkaValues =
+          input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
+
+      Schema errorSchema = ErrorHandling.errorSchemaBytes();
+      PCollectionTuple outputTuple =
+          kafkaValues.apply(
+              ParDo.of(
+                      new ErrorFn(
+                          "Kafka-read-error-counter", valueMapper, errorSchema, handleErrors))
+                  .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+
+      PCollectionRowTuple outputRows =
+          PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));
+
+      PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
+      if (handleErrors) {
+        outputRows =
+            outputRows.and(
+                checkArgumentNotNull(configuration.getErrorHandling()).getOutput(), errorOutput);
+      }
+      return outputRows;
+    }
+  }
+
   public static class ErrorFn extends DoFn<byte[], Row> {
     private final SerializableFunction<byte[], Row> valueMapper;
     private final Counter errorCounter;
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslation.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslation.java
new file mode 100644
index 000000000000..4b83e2b6f558
--- /dev/null
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslation.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.apache.beam.sdk.io.kafka.KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform;
+import static org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform;
+import static org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation.SchemaTransformPayloadTranslator;
+
+import com.google.auto.service.AutoService;
+import java.util.Map;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.util.construction.PTransformTranslation;
+import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+
+public class KafkaSchemaTransformTranslation {
+  static class KafkaReadSchemaTransformTranslator
+      extends SchemaTransformPayloadTranslator<KafkaReadSchemaTransform> {
+    @Override
+    public SchemaTransformProvider provider() {
+      return new KafkaReadSchemaTransformProvider();
+    }
+
+    @Override
+    public Row toConfigRow(KafkaReadSchemaTransform transform) {
+      return transform.getConfigurationRow();
+    }
+  }
+
+  @AutoService(TransformPayloadTranslatorRegistrar.class)
+  public static class ReadRegistrar implements TransformPayloadTranslatorRegistrar {
+    @Override
+    @SuppressWarnings({
+      "rawtypes",
+    })
+    public Map<
+            ? extends Class<? extends PTransform>,
+            ? extends PTransformTranslation.TransformPayloadTranslator>
+        getTransformPayloadTranslators() {
+      return ImmutableMap
+          .<Class<? extends PTransform>, PTransformTranslation.TransformPayloadTranslator>builder()
+          .put(KafkaReadSchemaTransform.class, new KafkaReadSchemaTransformTranslator())
+          .build();
+    }
+  }
+
+  static class KafkaWriteSchemaTransformTranslator
+      extends SchemaTransformPayloadTranslator<KafkaWriteSchemaTransform> {
+    @Override
+    public SchemaTransformProvider provider() {
+      return new KafkaWriteSchemaTransformProvider();
+    }
+
+    @Override
+    public Row toConfigRow(KafkaWriteSchemaTransform transform) {
+      return transform.getConfigurationRow();
+    }
+  }
+
+  @AutoService(TransformPayloadTranslatorRegistrar.class)
+  public static class WriteRegistrar implements TransformPayloadTranslatorRegistrar {
+    @Override
+    @SuppressWarnings({
+      "rawtypes",
+    })
+    public Map<
+            ? extends Class<? extends PTransform>,
+            ? extends PTransformTranslation.TransformPayloadTranslator>
+        getTransformPayloadTranslators() {
+      return ImmutableMap
+          .<Class<? extends PTransform>, PTransformTranslation.TransformPayloadTranslator>builder()
+          .put(KafkaWriteSchemaTransform.class, new KafkaWriteSchemaTransformTranslator())
+          .build();
+    }
+  }
+}
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
index 26f37b790ef8..09b338492b47 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
@@ -31,7 +31,9 @@
 import org.apache.beam.sdk.metrics.Counter;
 import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.schemas.AutoValueSchema;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
 import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
@@ -99,6 +101,20 @@ static final class KafkaWriteSchemaTransform extends SchemaTransform implements
       this.configuration = configuration;
     }
 
+    Row getConfigurationRow() {
+      try {
+        // To stay consistent with our SchemaTransform configuration naming conventions,
+        // we sort lexicographically
+        return SchemaRegistry.createDefault()
+            .getToRowFunction(KafkaWriteSchemaTransformConfiguration.class)
+            .apply(configuration)
+            .sorted()
+            .toSnakeCase();
+      } catch (NoSuchSchemaException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
     public static class ErrorCounterFn extends DoFn<Row, KV<byte[], byte[]>> {
       private final SerializableFunction<Row, byte[]> toBytesFn;
       private final Counter errorCounter;
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
index ab6ac52e318d..4d38636892c2 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
@@ -48,6 +48,7 @@
 import org.apache.beam.sdk.io.kafka.ReadFromKafkaDoFnTest.FailingDeserializer;
 import org.apache.beam.sdk.io.synthetic.SyntheticBoundedSource;
 import org.apache.beam.sdk.io.synthetic.SyntheticSourceOptions;
+import org.apache.beam.sdk.managed.Managed;
 import org.apache.beam.sdk.options.Default;
 import org.apache.beam.sdk.options.Description;
 import org.apache.beam.sdk.options.ExperimentalOptions;
@@ -607,18 +608,18 @@ public void testKafkaWithDelayedStopReadingFunction() {
   private static final int FIVE_MINUTES_IN_MS = 5 * 60 * 1000;
 
   @Test(timeout = FIVE_MINUTES_IN_MS)
-  public void testKafkaViaSchemaTransformJson() {
-    runReadWriteKafkaViaSchemaTransforms(
+  public void testKafkaViaManagedSchemaTransformJson() {
+    runReadWriteKafkaViaManagedSchemaTransforms(
         "JSON", SCHEMA_IN_JSON, JsonUtils.beamSchemaFromJsonSchema(SCHEMA_IN_JSON));
   }
 
   @Test(timeout = FIVE_MINUTES_IN_MS)
-  public void testKafkaViaSchemaTransformAvro() {
-    runReadWriteKafkaViaSchemaTransforms(
+  public void testKafkaViaManagedSchemaTransformAvro() {
+    runReadWriteKafkaViaManagedSchemaTransforms(
         "AVRO", AvroUtils.toAvroSchema(KAFKA_TOPIC_SCHEMA).toString(), KAFKA_TOPIC_SCHEMA);
   }
 
-  public void runReadWriteKafkaViaSchemaTransforms(
+  public void runReadWriteKafkaViaManagedSchemaTransforms(
       String format, String schemaDefinition, Schema beamSchema) {
     String topicName = options.getKafkaTopic() + "-schema-transform" + UUID.randomUUID();
     PCollectionRowTuple.of(
@@ -646,13 +647,12 @@ public void runReadWriteKafkaViaSchemaTransforms(
                 .setRowSchema(beamSchema))
         .apply(
             "Write to Kafka",
-            new KafkaWriteSchemaTransformProvider()
-                .from(
-                    KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransformConfiguration
-                        .builder()
-                        .setTopic(topicName)
-                        .setBootstrapServers(options.getKafkaBootstrapServerAddresses())
-                        .setFormat(format)
+            Managed.write(Managed.KAFKA)
+                .withConfig(
+                    ImmutableMap.<String, Object>builder()
+                        .put("topic", topicName)
+                        .put("bootstrap_servers", options.getKafkaBootstrapServerAddresses())
+                        .put("format", format)
                         .build()));
 
     PAssert.that(
@@ -661,15 +661,18 @@ public void runReadWriteKafkaViaSchemaTransforms(
                     "Read from unbounded Kafka",
                     // A timeout of 30s for local, container-based tests, and 2 minutes for
                     // real-kafka tests.
-                    new KafkaReadSchemaTransformProvider(
-                            true, options.isWithTestcontainers() ? 30 : 120)
-                        .from(
-                            KafkaReadSchemaTransformConfiguration.builder()
-                                .setFormat(format)
-                                .setAutoOffsetResetConfig("earliest")
-                                .setSchema(schemaDefinition)
-                                .setTopic(topicName)
-                                .setBootstrapServers(options.getKafkaBootstrapServerAddresses())
+                    Managed.read(Managed.KAFKA)
+                        .withConfig(
+                            ImmutableMap.<String, Object>builder()
+                                .put("format", format)
+                                .put("auto_offset_reset_config", "earliest")
+                                .put("schema", schemaDefinition)
+                                .put("topic", topicName)
+                                .put(
+                                    "bootstrap_servers", options.getKafkaBootstrapServerAddresses())
+                                .put(
+                                    "max_read_time_seconds",
+                                    options.isWithTestcontainers() ? 30 : 120)
                                 .build()))
                 .get("output"))
         .containsInAnyOrder(
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
index dfe062e1eef4..19c336e1d24e 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
@@ -34,9 +34,11 @@
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.managed.Managed;
 import org.apache.beam.sdk.managed.ManagedTransformConstants;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
 import org.apache.beam.sdk.schemas.utils.YamlUtils;
 import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
@@ -131,7 +133,8 @@ public void testFindTransformAndMakeItWork() {
             "confluent_schema_registry_url",
             "error_handling",
             "file_descriptor_path",
-            "message_name"),
+            "message_name",
+            "max_read_time_seconds"),
         kafkaProvider.configurationSchema().getFields().stream()
             .map(field -> field.getName())
             .collect(Collectors.toSet()));
@@ -232,22 +235,23 @@ public void testBuildTransformWithProtoFormatWrongMessageName() {
             .collect(Collectors.toList());
     KafkaReadSchemaTransformProvider kafkaProvider =
         (KafkaReadSchemaTransformProvider) providers.get(0);
+    SchemaTransform transform =
+        kafkaProvider.from(
+            KafkaReadSchemaTransformConfiguration.builder()
+                .setTopic("anytopic")
+                .setBootstrapServers("anybootstrap")
+                .setFormat("PROTO")
+                .setMessageName("MyOtherMessage")
+                .setFileDescriptorPath(
+                    Objects.requireNonNull(
+                            getClass()
+                                .getResource("/proto_byte/file_descriptor/proto_byte_utils.pb"))
+                        .getPath())
+                .build());
 
     assertThrows(
         NullPointerException.class,
-        () ->
-            kafkaProvider.from(
-                KafkaReadSchemaTransformConfiguration.builder()
-                    .setTopic("anytopic")
-                    .setBootstrapServers("anybootstrap")
-                    .setFormat("PROTO")
-                    .setMessageName("MyOtherMessage")
-                    .setFileDescriptorPath(
-                        Objects.requireNonNull(
-                                getClass()
-                                    .getResource("/proto_byte/file_descriptor/proto_byte_utils.pb"))
-                            .getPath())
-                    .build()));
+        () -> transform.expand(PCollectionRowTuple.empty(Pipeline.create())));
   }
 
   @Test
@@ -281,17 +285,18 @@ public void testBuildTransformWithoutProtoSchemaFormat() {
             .collect(Collectors.toList());
     KafkaReadSchemaTransformProvider kafkaProvider =
         (KafkaReadSchemaTransformProvider) providers.get(0);
+    SchemaTransform transform =
+        kafkaProvider.from(
+            KafkaReadSchemaTransformConfiguration.builder()
+                .setTopic("anytopic")
+                .setBootstrapServers("anybootstrap")
+                .setFormat("PROTO")
+                .setMessageName("MyMessage")
+                .build());
 
     assertThrows(
         IllegalArgumentException.class,
-        () ->
-            kafkaProvider.from(
-                KafkaReadSchemaTransformConfiguration.builder()
-                    .setTopic("anytopic")
-                    .setBootstrapServers("anybootstrap")
-                    .setFormat("PROTO")
-                    .setMessageName("MyMessage")
-                    .build()));
+        () -> transform.expand(PCollectionRowTuple.empty(Pipeline.create())));
   }
 
   @Test
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslationTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslationTest.java
new file mode 100644
index 000000000000..b297227bb7aa
--- /dev/null
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslationTest.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM;
+import static org.apache.beam.sdk.io.kafka.KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform;
+import static org.apache.beam.sdk.io.kafka.KafkaSchemaTransformTranslation.KafkaReadSchemaTransformTranslator;
+import static org.apache.beam.sdk.io.kafka.KafkaSchemaTransformTranslation.KafkaWriteSchemaTransformTranslator;
+import static org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform;
+import static org.junit.Assert.assertEquals;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.util.construction.BeamUrns;
+import org.apache.beam.sdk.util.construction.PipelineTranslation;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.rules.TemporaryFolder;
+
+public class KafkaSchemaTransformTranslationTest {
+  @ClassRule public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder();
+
+  @Rule public transient ExpectedException thrown = ExpectedException.none();
+
+  static final KafkaWriteSchemaTransformProvider WRITE_PROVIDER =
+      new KafkaWriteSchemaTransformProvider();
+  static final KafkaReadSchemaTransformProvider READ_PROVIDER =
+      new KafkaReadSchemaTransformProvider();
+
+  static final Row READ_CONFIG =
+      Row.withSchema(READ_PROVIDER.configurationSchema())
+          .withFieldValue("format", "RAW")
+          .withFieldValue("topic", "test_topic")
+          .withFieldValue("bootstrap_servers", "host:port")
+          .withFieldValue("confluent_schema_registry_url", null)
+          .withFieldValue("confluent_schema_registry_subject", null)
+          .withFieldValue("schema", null)
+          .withFieldValue("file_descriptor_path", "testPath")
+          .withFieldValue("message_name", "test_message")
+          .withFieldValue("auto_offset_reset_config", "earliest")
+          .withFieldValue("consumer_config_updates", ImmutableMap.builder().build())
+          .withFieldValue("error_handling", null)
+          .build();
+
+  static final Row WRITE_CONFIG =
+      Row.withSchema(WRITE_PROVIDER.configurationSchema())
+          .withFieldValue("format", "RAW")
+          .withFieldValue("topic", "test_topic")
+          .withFieldValue("bootstrap_servers", "host:port")
+          .withFieldValue("producer_config_updates", ImmutableMap.builder().build())
+          .withFieldValue("error_handling", null)
+          .withFieldValue("file_descriptor_path", "testPath")
+          .withFieldValue("message_name", "test_message")
+          .withFieldValue("schema", "test_schema")
+          .build();
+
+  @Test
+  public void testRecreateWriteTransformFromRow() {
+    KafkaWriteSchemaTransform writeTransform =
+        (KafkaWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG);
+
+    KafkaWriteSchemaTransformTranslator translator = new KafkaWriteSchemaTransformTranslator();
+    Row translatedRow = translator.toConfigRow(writeTransform);
+
+    KafkaWriteSchemaTransform writeTransformFromRow =
+        translator.fromConfigRow(translatedRow, PipelineOptionsFactory.create());
+
+    assertEquals(WRITE_CONFIG, writeTransformFromRow.getConfigurationRow());
+  }
+
+  @Test
+  public void testWriteTransformProtoTranslation()
+      throws InvalidProtocolBufferException, IOException {
+    // First build a pipeline
+    Pipeline p = Pipeline.create();
+    Schema inputSchema = Schema.builder().addByteArrayField("b").build();
+    PCollection<Row> input =
+        p.apply(
+                Create.of(
+                    Collections.singletonList(
+                        Row.withSchema(inputSchema).addValue(new byte[] {1, 2, 3}).build())))
+            .setRowSchema(inputSchema);
+
+    KafkaWriteSchemaTransform writeTransform =
+        (KafkaWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG);
+    PCollectionRowTuple.of("input", input).apply(writeTransform);
+
+    // Then translate the pipeline to a proto and extract KafkaWriteSchemaTransform proto
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    List<RunnerApi.PTransform> writeTransformProto =
+        pipelineProto.getComponents().getTransformsMap().values().stream()
+            .filter(
+                tr -> {
+                  RunnerApi.FunctionSpec spec = tr.getSpec();
+                  try {
+                    return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))
+                        && SchemaTransformPayload.parseFrom(spec.getPayload())
+                            .getIdentifier()
+                            .equals(WRITE_PROVIDER.identifier());
+                  } catch (InvalidProtocolBufferException e) {
+                    throw new RuntimeException(e);
+                  }
+                })
+            .collect(Collectors.toList());
+    assertEquals(1, writeTransformProto.size());
+    RunnerApi.FunctionSpec spec = writeTransformProto.get(0).getSpec();
+
+    // Check that the proto contains correct values
+    SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload());
+    Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());
+    assertEquals(WRITE_PROVIDER.configurationSchema(), schemaFromSpec);
+    Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput());
+
+    assertEquals(WRITE_CONFIG, rowFromSpec);
+
+    // Use the information in the proto to recreate the KafkaWriteSchemaTransform
+    KafkaWriteSchemaTransformTranslator translator = new KafkaWriteSchemaTransformTranslator();
+    KafkaWriteSchemaTransform writeTransformFromSpec =
+        translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create());
+
+    assertEquals(WRITE_CONFIG, writeTransformFromSpec.getConfigurationRow());
+  }
+
+  @Test
+  public void testReCreateReadTransformFromRow() {
+    // setting a subset of fields here.
+    KafkaReadSchemaTransform readTransform =
+        (KafkaReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG);
+
+    KafkaReadSchemaTransformTranslator translator = new KafkaReadSchemaTransformTranslator();
+    Row row = translator.toConfigRow(readTransform);
+
+    KafkaReadSchemaTransform readTransformFromRow =
+        translator.fromConfigRow(row, PipelineOptionsFactory.create());
+
+    assertEquals(READ_CONFIG, readTransformFromRow.getConfigurationRow());
+  }
+
+  @Test
+  public void testReadTransformProtoTranslation()
+      throws InvalidProtocolBufferException, IOException {
+    // First build a pipeline
+    Pipeline p = Pipeline.create();
+
+    KafkaReadSchemaTransform readTransform =
+        (KafkaReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG);
+
+    PCollectionRowTuple.empty(p).apply(readTransform);
+
+    // Then translate the pipeline to a proto and extract KafkaReadSchemaTransform proto
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    List<RunnerApi.PTransform> readTransformProto =
+        pipelineProto.getComponents().getTransformsMap().values().stream()
+            .filter(
+                tr -> {
+                  RunnerApi.FunctionSpec spec = tr.getSpec();
+                  try {
+                    return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))
+                        && SchemaTransformPayload.parseFrom(spec.getPayload())
+                            .getIdentifier()
+                            .equals(READ_PROVIDER.identifier());
+                  } catch (InvalidProtocolBufferException e) {
+                    throw new RuntimeException(e);
+                  }
+                })
+            .collect(Collectors.toList());
+    assertEquals(1, readTransformProto.size());
+    RunnerApi.FunctionSpec spec = readTransformProto.get(0).getSpec();
+
+    // Check that the proto contains correct values
+    SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload());
+    Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());
+    assertEquals(READ_PROVIDER.configurationSchema(), schemaFromSpec);
+    Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput());
+    assertEquals(READ_CONFIG, rowFromSpec);
+
+    // Use the information in the proto to recreate the KafkaReadSchemaTransform
+    KafkaReadSchemaTransformTranslator translator = new KafkaReadSchemaTransformTranslator();
+    KafkaReadSchemaTransform readTransformFromSpec =
+        translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create());
+
+    assertEquals(READ_CONFIG, readTransformFromSpec.getConfigurationRow());
+  }
+}
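
Note (not part of the change itself): the following is a minimal, hedged sketch of how the new max_read_time_seconds option is meant to be consumed through the Managed API, mirroring the configuration keys exercised in the updated KafkaIOIT test above. The wrapper class name, topic, and bootstrap server address are placeholders chosen for illustration.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;

public class ManagedKafkaReadExample {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    PCollectionRowTuple result =
        PCollectionRowTuple.empty(pipeline)
            .apply(
                "Read from Kafka",
                Managed.read(Managed.KAFKA)
                    .withConfig(
                        ImmutableMap.<String, Object>builder()
                            .put("format", "RAW")
                            .put("topic", "my-topic") // placeholder
                            .put("bootstrap_servers", "localhost:9092") // placeholder
                            .put("auto_offset_reset_config", "earliest")
                            // New option from this change: an upper bound on how long to read,
                            // which makes the otherwise unbounded Kafka read finite.
                            .put("max_read_time_seconds", 30)
                            .build()));

    // With format RAW, each output row carries a single "payload" BYTES field.
    PCollection<Row> rows = result.get("output");
    pipeline.run().waitUntilFinish();
  }
}

Because withMaxReadTime bounds the read, this is what lets the integration test above assert on the pipeline output without relying on the removed test-only provider constructor.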
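
A second hedged sketch of what the new KafkaSchemaTransformTranslation translators provide: the configuration-Row round trip (toConfigRow/fromConfigRow) that backs portable translation and upgrades. It assumes same-package access, since KafkaReadSchemaTransform and its translator are package-private, and the topic, bootstrap address, and class name are hypothetical.

package org.apache.beam.sdk.io.kafka;

import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.values.Row;

public class KafkaConfigRowRoundTripExample {
  public static void main(String[] args) {
    KafkaReadSchemaTransformProvider provider = new KafkaReadSchemaTransformProvider();
    // Build a config Row against the provider's (sorted, snake_case) configuration schema.
    Row config =
        Row.withSchema(provider.configurationSchema())
            .withFieldValue("format", "RAW")
            .withFieldValue("topic", "my-topic") // placeholder
            .withFieldValue("bootstrap_servers", "localhost:9092") // placeholder
            .build();

    // from(Row) builds the transform; toConfigRow() yields the Row embedded in the
    // SchemaTransformPayload; fromConfigRow() rebuilds an equivalent transform from it.
    KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform transform =
        (KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform) provider.from(config);
    KafkaSchemaTransformTranslation.KafkaReadSchemaTransformTranslator translator =
        new KafkaSchemaTransformTranslation.KafkaReadSchemaTransformTranslator();
    Row configRow = translator.toConfigRow(transform);
    KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform recreated =
        translator.fromConfigRow(configRow, PipelineOptionsFactory.create());
    System.out.println(configRow.equals(recreated.getConfigurationRow())); // expected: true
  }
}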