diff --git a/sdks/java/extensions/protobuf/build.gradle b/sdks/java/extensions/protobuf/build.gradle index 568d4f220867..1582492c293e 100644 --- a/sdks/java/extensions/protobuf/build.gradle +++ b/sdks/java/extensions/protobuf/build.gradle @@ -39,6 +39,8 @@ dependencies { implementation library.java.slf4j_api implementation project(path: ":sdks:java:core", configuration: "shadow") implementation library.java.protobuf_java + implementation("com.squareup.wire:wire-schema-jvm:4.9.3") + implementation("io.apicurio:apicurio-registry-protobuf-schema-utilities:3.0.0.M2") testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") testImplementation library.java.junit testRuntimeOnly library.java.slf4j_jdk14 diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtils.java index f156fed0f38c..02419ec0f619 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtils.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtils.java @@ -24,6 +24,10 @@ import com.google.protobuf.Descriptors; import com.google.protobuf.DynamicMessage; import com.google.protobuf.InvalidProtocolBufferException; +import com.squareup.wire.schema.Location; +import com.squareup.wire.schema.internal.parser.ProtoFileElement; +import com.squareup.wire.schema.internal.parser.ProtoParser; +import io.apicurio.registry.utils.protobuf.schema.FileDescriptorUtils; import java.io.IOException; import java.io.InputStream; import java.io.Serializable; @@ -55,6 +59,8 @@ public class ProtoByteUtils { private static final Logger LOG = LoggerFactory.getLogger(ProtoByteUtils.class); + private static final Location LOCATION = Location.get(""); + /** * Retrieves a Beam Schema from a Protocol Buffer message. * @@ -68,6 +74,68 @@ public static Schema getBeamSchemaFromProto(String fileDescriptorPath, String me return ProtoDynamicMessageSchema.forDescriptor(protoDomain, messageName).getSchema(); } + /** + * Parses the given Protocol Buffers schema string, retrieves the Descriptor for the specified + * message name, and constructs a Beam Schema from it. + * + * @param schemaString The Protocol Buffers schema string. + * @param messageName The name of the message type for which the Beam Schema is desired. + * @return The Beam Schema constructed from the specified Protocol Buffers schema. + * @throws RuntimeException If there is an error during parsing, descriptor retrieval, or schema + * construction. + */ + public static Schema getBeamSchemaFromProtoSchema(String schemaString, String messageName) { + Descriptors.Descriptor descriptor = getDescriptorFromProtoSchema(schemaString, messageName); + return ProtoDynamicMessageSchema.forDescriptor(ProtoDomain.buildFrom(descriptor), descriptor) + .getSchema(); + } + + /** + * Parses the given Protocol Buffers schema string, retrieves the FileDescriptor, and returns the + * Descriptor for the specified message name. + * + * @param schemaString The Protocol Buffers schema string. + * @param messageName The name of the message type for which the descriptor is desired. + * @return The Descriptor for the specified message name. + * @throws RuntimeException If there is an error during parsing or descriptor validation. 
+ */ + private static Descriptors.Descriptor getDescriptorFromProtoSchema( + final String schemaString, final String messageName) { + ProtoFileElement result = ProtoParser.Companion.parse(LOCATION, schemaString); + try { + Descriptors.FileDescriptor fileDescriptor = + FileDescriptorUtils.protoFileToFileDescriptor(result); + return fileDescriptor.findMessageTypeByName(messageName); + } catch (Descriptors.DescriptorValidationException e) { + throw new RuntimeException(e); + } + } + + public static SerializableFunction getProtoBytesToRowFromSchemaFunction( + String schemaString, String messageName) { + + Descriptors.Descriptor descriptor = getDescriptorFromProtoSchema(schemaString, messageName); + + ProtoDynamicMessageSchema protoDynamicMessageSchema = + ProtoDynamicMessageSchema.forDescriptor(ProtoDomain.buildFrom(descriptor), descriptor); + return new SimpleFunction() { + @Override + public Row apply(byte[] input) { + try { + Descriptors.Descriptor descriptorFunction = + getDescriptorFromProtoSchema(schemaString, messageName); + DynamicMessage dynamicMessage = DynamicMessage.parseFrom(descriptorFunction, input); + SerializableFunction res = + protoDynamicMessageSchema.getToRowFunction(); + return res.apply(dynamicMessage); + } catch (InvalidProtocolBufferException e) { + LOG.error("Error parsing to DynamicMessage", e); + throw new RuntimeException(e); + } + } + }; + } + public static SerializableFunction getProtoBytesToRowFunction( String fileDescriptorPath, String messageName) { @@ -96,6 +164,23 @@ public Row apply(byte[] input) { }; } + public static SerializableFunction getRowToProtoBytesFromSchema( + String schemaString, String messageName) { + + Descriptors.Descriptor descriptor = getDescriptorFromProtoSchema(schemaString, messageName); + + ProtoDynamicMessageSchema protoDynamicMessageSchema = + ProtoDynamicMessageSchema.forDescriptor(ProtoDomain.buildFrom(descriptor), descriptor); + return new SimpleFunction() { + @Override + public byte[] apply(Row input) { + SerializableFunction res = + protoDynamicMessageSchema.getFromRowFunction(); + return res.apply(input).toByteArray(); + } + }; + } + public static SerializableFunction getRowToProtoBytes( String fileDescriptorPath, String messageName) { ProtoSchemaInfo dynamicProtoDomain = getProtoDomain(fileDescriptorPath); diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtilsTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtilsTest.java index 2a4cb4b5d5fb..04bcde6a0fe0 100644 --- a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtilsTest.java +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtilsTest.java @@ -29,6 +29,25 @@ @RunWith(JUnit4.class) public class ProtoByteUtilsTest { + private static final String PROTO_STRING_SCHEMA = + "syntax = \"proto3\";\n" + + "\n" + + "message MyMessage {\n" + + " int32 id = 1;\n" + + " string name = 2;\n" + + " bool active = 3;\n" + + "\n" + + " // Nested field\n" + + " message Address {\n" + + " string street = 1;\n" + + " string city = 2;\n" + + " string state = 3;\n" + + " string zip_code = 4;\n" + + " }\n" + + "\n" + + " Address address = 4;\n" + + "}"; + private static final String DESCRIPTOR_PATH = Objects.requireNonNull( ProtoByteUtilsTest.class.getResource( @@ -59,6 +78,12 @@ public void testProtoSchemaToBeamSchema() { Assert.assertEquals(schema.getFieldNames(), SCHEMA.getFieldNames()); } + 
@Test + public void testProtoSchemaStringToBeamSchema() { + Schema schema = ProtoByteUtils.getBeamSchemaFromProtoSchema(PROTO_STRING_SCHEMA, "MyMessage"); + Assert.assertEquals(schema.getFieldNames(), SCHEMA.getFieldNames()); + } + @Test public void testProtoBytesToRowFunctionGenerateSerializableFunction() { SerializableFunction protoBytesToRowFunction = @@ -66,6 +91,13 @@ public void testProtoBytesToRowFunctionGenerateSerializableFunction() { Assert.assertNotNull(protoBytesToRowFunction); } + @Test + public void testProtoBytesToRowSchemaStringGenerateSerializableFunction() { + SerializableFunction protoBytesToRowFunction = + ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(PROTO_STRING_SCHEMA, "MyMessage"); + Assert.assertNotNull(protoBytesToRowFunction); + } + @Test(expected = java.lang.RuntimeException.class) public void testProtoBytesToRowFunctionReturnsRowFailure() { // Create a proto bytes to row function @@ -95,4 +127,21 @@ public void testRowToProtoFunction() { Assert.assertNotNull( ProtoByteUtils.getRowToProtoBytes(DESCRIPTOR_PATH, MESSAGE_NAME).apply(row)); } + + @Test + public void testRowToProtoSchemaFunction() { + Row row = + Row.withSchema(SCHEMA) + .withFieldValue("id", 1234) + .withFieldValue("name", "Doe") + .withFieldValue("active", false) + .withFieldValue("address.city", "seattle") + .withFieldValue("address.street", "fake street") + .withFieldValue("address.zip_code", "TO-1234") + .withFieldValue("address.state", "wa") + .build(); + + Assert.assertNotNull( + ProtoByteUtils.getRowToProtoBytesFromSchema(PROTO_STRING_SCHEMA, "MyMessage").apply(row)); + } } diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle index 86e43d6de6aa..ea53a4d60dbd 100644 --- a/sdks/java/io/kafka/build.gradle +++ b/sdks/java/io/kafka/build.gradle @@ -89,6 +89,7 @@ dependencies { testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") testImplementation project(":sdks:java:io:synthetic") testImplementation project(path: ":sdks:java:extensions:avro", configuration: "testRuntimeMigration") + testImplementation project(path: ":sdks:java:extensions:protobuf", configuration: "testRuntimeMigration") testImplementation project(path: ":sdks:java:io:common", configuration: "testRuntimeMigration") testImplementation project(path: ":sdks:java:testing:test-utils", configuration: "testRuntimeMigration") // For testing Cross-language transforms diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java index 2fa365b1c7f3..d95c49894a2c 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java @@ -51,6 +51,29 @@ public void validate() { final String dataFormat = this.getFormat(); assert dataFormat == null || VALID_DATA_FORMATS.contains(dataFormat) : "Valid data formats are " + VALID_DATA_FORMATS; + + final String inputSchema = this.getSchema(); + final String messageName = this.getMessageName(); + final String fileDescriptorPath = this.getFileDescriptorPath(); + final String confluentSchemaRegUrl = this.getConfluentSchemaRegistryUrl(); + final String confluentSchemaRegSubject = this.getConfluentSchemaRegistrySubject(); + + if (confluentSchemaRegUrl != null) { + assert confluentSchemaRegSubject != null + : "To read from Kafka, a 
Confluent Schema Registry subject must also be provided when a Schema Registry URL is set."; + } else if (dataFormat != null && dataFormat.equals("RAW")) { + assert inputSchema == null : "To read from Kafka in RAW format, you can't provide a schema."; + } else if (dataFormat != null && dataFormat.equals("JSON")) { + assert inputSchema != null : "To read from Kafka in JSON format, you must provide a schema."; + } else if (dataFormat != null && dataFormat.equals("PROTO")) { + assert messageName != null + : "To read from Kafka in PROTO format, messageName must be provided."; + assert fileDescriptorPath != null || inputSchema != null + : "To read from Kafka in PROTO format, fileDescriptorPath or schema must be provided."; + } else { + assert inputSchema != null : "To read from Kafka in AVRO format, you must provide a schema."; + } } /** Instantiates a {@link KafkaReadSchemaTransformConfiguration.Builder} instance. */ diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java index 996976ee9a75..10a347929ee0 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java @@ -17,8 +17,6 @@ */ package org.apache.beam.sdk.io.kafka; -import static org.apache.beam.sdk.io.kafka.KafkaReadSchemaTransformConfiguration.VALID_DATA_FORMATS; - import com.google.auto.service.AutoService; import java.io.FileOutputStream; import java.io.IOException; @@ -32,7 +30,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.stream.Collectors; import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; @@ -61,7 +58,6 @@ import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.kafka.clients.consumer.Consumer; @@ -98,10 +94,13 @@ protected Class configurationClass() { return KafkaReadSchemaTransformConfiguration.class; } + @SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) + }) @Override protected SchemaTransform from(KafkaReadSchemaTransformConfiguration configuration) { final String inputSchema = configuration.getSchema(); - final Integer groupId = configuration.hashCode() % Integer.MAX_VALUE; + final int groupId = configuration.hashCode() % Integer.MAX_VALUE; final String autoOffsetReset = MoreObjects.firstNonNull(configuration.getAutoOffsetResetConfig(), "latest"); @@ -115,101 +114,17 @@ protected SchemaTransform from(KafkaReadSchemaTransformConfigurati String format = configuration.getFormat(); boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling()); - String descriptorPath = configuration.getFileDescriptorPath(); - String messageName = configuration.getMessageName(); - - if ((format != null && VALID_DATA_FORMATS.contains(format)) - ||
(!Strings.isNullOrEmpty(inputSchema) && !Objects.equals(format, "RAW")) - || (Objects.equals(format, "PROTO") - && !Strings.isNullOrEmpty(descriptorPath) - && !Strings.isNullOrEmpty(messageName))) { - SerializableFunction valueMapper; - Schema beamSchema; - if (format != null && format.equals("RAW")) { - if (inputSchema != null) { - throw new IllegalArgumentException( - "To read from Kafka in RAW format, you can't provide a schema."); - } - beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build(); - valueMapper = getRawBytesToRowFunction(beamSchema); - } else if (format != null && format.equals("PROTO")) { - if (descriptorPath == null || messageName == null) { - throw new IllegalArgumentException( - "Expecting both descriptorPath and messageName to be non-null."); - } - valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(descriptorPath, messageName); - beamSchema = ProtoByteUtils.getBeamSchemaFromProto(descriptorPath, messageName); - } else { - assert Strings.isNullOrEmpty(configuration.getConfluentSchemaRegistryUrl()) - : "To read from Kafka, a schema must be provided directly or though Confluent " - + "Schema Registry, but not both."; - if (inputSchema == null) { - throw new IllegalArgumentException( - "To read from Kafka in JSON or AVRO format, you must provide a schema."); - } - beamSchema = - Objects.equals(format, "JSON") - ? JsonUtils.beamSchemaFromJsonSchema(inputSchema) - : AvroUtils.toBeamSchema(new org.apache.avro.Schema.Parser().parse(inputSchema)); - valueMapper = - Objects.equals(format, "JSON") - ? JsonUtils.getJsonBytesToRowFunction(beamSchema) - : AvroUtils.getAvroBytesToRowFunction(beamSchema); - } - return new SchemaTransform() { - @Override - public PCollectionRowTuple expand(PCollectionRowTuple input) { - KafkaIO.Read kafkaRead = - KafkaIO.readBytes() - .withConsumerConfigUpdates(consumerConfigs) - .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores()) - .withTopic(configuration.getTopic()) - .withBootstrapServers(configuration.getBootstrapServers()); - if (isTest) { - kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs)); - } - PCollection kafkaValues = - input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create()); + SerializableFunction valueMapper; + Schema beamSchema; - Schema errorSchema = ErrorHandling.errorSchemaBytes(); - PCollectionTuple outputTuple = - kafkaValues.apply( - ParDo.of( - new ErrorFn( - "Kafka-read-error-counter", valueMapper, errorSchema, handleErrors)) - .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); - - PCollectionRowTuple outputRows = - PCollectionRowTuple.of( - "output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema)); - - PCollection errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema); - if (handleErrors) { - ErrorHandling errorHandling = configuration.getErrorHandling(); - if (errorHandling == null) { - throw new IllegalArgumentException("You must specify an error handling option."); - } - outputRows = outputRows.and(errorHandling.getOutput(), errorOutput); - } - return outputRows; - } - }; - } else { - assert !Strings.isNullOrEmpty(configuration.getConfluentSchemaRegistryUrl()) - : "To read from Kafka, a schema must be provided directly or though Confluent " - + "Schema Registry. 
Neither seems to have been provided."; + String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl(); + if (confluentSchemaRegUrl != null) { return new SchemaTransform() { @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { - final String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl(); final String confluentSchemaRegSubject = configuration.getConfluentSchemaRegistrySubject(); - if (confluentSchemaRegUrl == null || confluentSchemaRegSubject == null) { - throw new IllegalArgumentException( - "To read from Kafka, a schema must be provided directly or though Confluent " - + "Schema Registry. Make sure you are providing one of these parameters."); - } KafkaIO.Read kafkaRead = KafkaIO.read() .withTopic(configuration.getTopic()) @@ -234,6 +149,62 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } }; } + + if (format.equals("RAW")) { + beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build(); + valueMapper = getRawBytesToRowFunction(beamSchema); + } else if (format.equals("PROTO")) { + String fileDescriptorPath = configuration.getFileDescriptorPath(); + String messageName = configuration.getMessageName(); + if (fileDescriptorPath != null) { + beamSchema = ProtoByteUtils.getBeamSchemaFromProto(fileDescriptorPath, messageName); + valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(fileDescriptorPath, messageName); + } else { + beamSchema = ProtoByteUtils.getBeamSchemaFromProtoSchema(inputSchema, messageName); + valueMapper = ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(inputSchema, messageName); + } + } else if (format.equals("JSON")) { + beamSchema = JsonUtils.beamSchemaFromJsonSchema(inputSchema); + valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema); + } else { + beamSchema = AvroUtils.toBeamSchema(new org.apache.avro.Schema.Parser().parse(inputSchema)); + valueMapper = AvroUtils.getAvroBytesToRowFunction(beamSchema); + } + + return new SchemaTransform() { + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + KafkaIO.Read kafkaRead = + KafkaIO.readBytes() + .withConsumerConfigUpdates(consumerConfigs) + .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores()) + .withTopic(configuration.getTopic()) + .withBootstrapServers(configuration.getBootstrapServers()); + if (isTest) { + kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs)); + } + + PCollection kafkaValues = + input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create()); + + Schema errorSchema = ErrorHandling.errorSchemaBytes(); + PCollectionTuple outputTuple = + kafkaValues.apply( + ParDo.of( + new ErrorFn( + "Kafka-read-error-counter", valueMapper, errorSchema, handleErrors)) + .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); + + PCollectionRowTuple outputRows = + PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema)); + + PCollection errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema); + if (handleErrors) { + outputRows = outputRows.and(configuration.getErrorHandling().getOutput(), errorOutput); + } + return outputRows; + } + }; } public static SerializableFunction getRawBytesToRowFunction(Schema rawSchema) { diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java index 694c3e9f2c14..26f37b790ef8 100644 --- 
a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java @@ -37,13 +37,14 @@ import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling; import org.apache.beam.sdk.schemas.utils.JsonUtils; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFn.ProcessElement; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.Row; @@ -68,8 +69,6 @@ public class KafkaWriteSchemaTransformProvider public static final TupleTag ERROR_TAG = new TupleTag() {}; public static final TupleTag> OUTPUT_TAG = new TupleTag>() {}; - public static final Schema ERROR_SCHEMA = - Schema.builder().addStringField("error").addNullableByteArrayField("row").build(); private static final Logger LOG = LoggerFactory.getLogger(KafkaWriteSchemaTransformProvider.class); @@ -101,25 +100,38 @@ static final class KafkaWriteSchemaTransform extends SchemaTransform implements } public static class ErrorCounterFn extends DoFn> { - private SerializableFunction toBytesFn; - private Counter errorCounter; + private final SerializableFunction toBytesFn; + private final Counter errorCounter; private Long errorsInBundle = 0L; - - public ErrorCounterFn(String name, SerializableFunction toBytesFn) { + private final boolean handleErrors; + private final Schema errorSchema; + + public ErrorCounterFn( + String name, + SerializableFunction toBytesFn, + Schema errorSchema, + boolean handleErrors) { this.toBytesFn = toBytesFn; - errorCounter = Metrics.counter(KafkaWriteSchemaTransformProvider.class, name); + this.errorCounter = Metrics.counter(KafkaWriteSchemaTransformProvider.class, name); + this.handleErrors = handleErrors; + this.errorSchema = errorSchema; } @ProcessElement public void process(@DoFn.Element Row row, MultiOutputReceiver receiver) { + KV output = null; try { - receiver.get(OUTPUT_TAG).output(KV.of(new byte[1], toBytesFn.apply(row))); + output = KV.of(new byte[1], toBytesFn.apply(row)); } catch (Exception e) { + if (!handleErrors) { + throw new RuntimeException(e); + } errorsInBundle += 1; LOG.warn("Error while processing the element", e); - receiver - .get(ERROR_TAG) - .output(Row.withSchema(ERROR_SCHEMA).addValues(e.toString(), row.toString()).build()); + receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, row, e)); + } + if (output != null) { + receiver.get(OUTPUT_TAG).output(output); } } @@ -130,6 +142,9 @@ public void finish() { } } + @SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) + }) @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { Schema inputSchema = input.get("input").getSchema(); @@ -139,7 +154,7 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { if (numFields != 1) { throw new IllegalArgumentException("Expecting exactly one field, found " + numFields); } - if 
(inputSchema.getField(0).getType().equals(Schema.FieldType.BYTES)) { + if (!inputSchema.getField(0).getType().equals(Schema.FieldType.BYTES)) { throw new IllegalArgumentException( "The input schema must have exactly one field of type byte."); } @@ -148,23 +163,38 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { toBytesFn = JsonUtils.getRowToJsonBytesFunction(inputSchema); } else if (configuration.getFormat().equals("PROTO")) { String descriptorPath = configuration.getFileDescriptorPath(); + String schema = configuration.getSchema(); String messageName = configuration.getMessageName(); - if (descriptorPath == null || messageName == null) { + if (messageName == null) { + throw new IllegalArgumentException("Expecting messageName to be non-null."); + } + if (descriptorPath != null && schema != null) { + throw new IllegalArgumentException( + "You must include a descriptorPath or a proto Schema but not both."); + } else if (descriptorPath != null) { + toBytesFn = ProtoByteUtils.getRowToProtoBytes(descriptorPath, messageName); + } else if (schema != null) { + toBytesFn = ProtoByteUtils.getRowToProtoBytesFromSchema(schema, messageName); + } else { throw new IllegalArgumentException( - "Expecting both descriptorPath and messageName to be non-null."); + "At least a descriptorPath or a proto Schema is required."); } - toBytesFn = ProtoByteUtils.getRowToProtoBytes(descriptorPath, messageName); + } else { toBytesFn = AvroUtils.getRowToAvroBytesFunction(inputSchema); } + boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling()); final Map configOverrides = configuration.getProducerConfigUpdates(); + Schema errorSchema = ErrorHandling.errorSchema(inputSchema); PCollectionTuple outputTuple = input .get("input") .apply( "Map rows to Kafka messages", - ParDo.of(new ErrorCounterFn("Kafka-write-error-counter", toBytesFn)) + ParDo.of( + new ErrorCounterFn( + "Kafka-write-error-counter", toBytesFn, errorSchema, handleErrors)) .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); outputTuple @@ -180,8 +210,11 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { .withKeySerializer(ByteArraySerializer.class) .withValueSerializer(ByteArraySerializer.class)); + // TODO: include output from KafkaIO Write once updated from PDone + PCollection errorOutput = + outputTuple.get(ERROR_TAG).setRowSchema(ErrorHandling.errorSchema(errorSchema)); return PCollectionRowTuple.of( - "errors", outputTuple.get(ERROR_TAG).setRowSchema(ERROR_SCHEMA)); + handleErrors ? configuration.getErrorHandling().getOutput() : "errors", errorOutput); } } @@ -232,6 +265,18 @@ public abstract static class KafkaWriteSchemaTransformConfiguration implements S + " of servers. | Format: host1:port1,host2:port2,...") public abstract String getBootstrapServers(); + @SchemaFieldDescription( + "A list of key-value pairs that act as configuration parameters for Kafka producers." + + " Most of these configurations will not be needed, but if you need to customize your Kafka producer," + + " you may use this. See a detailed list:" + + " https://docs.confluent.io/platform/current/installation/configuration/producer-configs.html") + @Nullable + public abstract Map getProducerConfigUpdates(); + + @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.") + @Nullable + public abstract ErrorHandling getErrorHandling(); + @SchemaFieldDescription( "The path to the Protocol Buffer File Descriptor Set file. 
This file is used for schema" + " definition and message serialization.") @@ -244,13 +289,8 @@ public abstract static class KafkaWriteSchemaTransformConfiguration implements S @Nullable public abstract String getMessageName(); - @SchemaFieldDescription( - "A list of key-value pairs that act as configuration parameters for Kafka producers." - + " Most of these configurations will not be needed, but if you need to customize your Kafka producer," - + " you may use this. See a detailed list:" - + " https://docs.confluent.io/platform/current/installation/configuration/producer-configs.html") @Nullable - public abstract Map getProducerConfigUpdates(); + public abstract String getSchema(); public static Builder builder() { return new AutoValue_KafkaWriteSchemaTransformProvider_KafkaWriteSchemaTransformConfiguration @@ -265,11 +305,15 @@ public abstract static class Builder { public abstract Builder setBootstrapServers(String bootstrapServers); + public abstract Builder setProducerConfigUpdates(Map producerConfigUpdates); + + public abstract Builder setErrorHandling(ErrorHandling errorHandling); + public abstract Builder setFileDescriptorPath(String fileDescriptorPath); public abstract Builder setMessageName(String messageName); - public abstract Builder setProducerConfigUpdates(Map producerConfigUpdates); + public abstract Builder setSchema(String schema); public abstract KafkaWriteSchemaTransformConfiguration build(); } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java index 27fa18715c32..4f133746b535 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java @@ -43,6 +43,25 @@ public class KafkaReadSchemaTransformProviderTest { + "\"name\":\"FullName\",\"fields\":[{\"name\":\"first\",\"type\":\"string\"}," + "{\"name\":\"last\",\"type\":\"string\"}]}"; + private static final String PROTO_SCHEMA = + "syntax = \"proto3\";\n" + + "\n" + + "message MyMessage {\n" + + " int32 id = 1;\n" + + " string name = 2;\n" + + " bool active = 3;\n" + + "\n" + + " // Nested field\n" + + " message Address {\n" + + " string street = 1;\n" + + " string city = 2;\n" + + " string state = 3;\n" + + " string zip_code = 4;\n" + + " }\n" + + "\n" + + " Address address = 4;\n" + + "}"; + @Test public void testValidConfigurations() { assertThrows( @@ -121,6 +140,7 @@ public void testBuildTransformWithAvroSchema() { (KafkaReadSchemaTransformProvider) providers.get(0); kafkaProvider.from( KafkaReadSchemaTransformConfiguration.builder() + .setFormat("AVRO") .setTopic("anytopic") .setBootstrapServers("anybootstrap") .setSchema(AVRO_SCHEMA) @@ -220,4 +240,48 @@ public void testBuildTransformWithProtoFormatWrongMessageName() { .getPath()) .build())); } + + @Test + public void testBuildTransformWithProtoSchemaFormat() { + ServiceLoader serviceLoader = + ServiceLoader.load(SchemaTransformProvider.class); + List providers = + StreamSupport.stream(serviceLoader.spliterator(), false) + .filter(provider -> provider.getClass() == KafkaReadSchemaTransformProvider.class) + .collect(Collectors.toList()); + KafkaReadSchemaTransformProvider kafkaProvider = + (KafkaReadSchemaTransformProvider) providers.get(0); + + kafkaProvider.from( + KafkaReadSchemaTransformConfiguration.builder() + 
.setTopic("anytopic") + .setBootstrapServers("anybootstrap") + .setFormat("PROTO") + .setMessageName("MyMessage") + .setSchema(PROTO_SCHEMA) + .build()); + } + + @Test + public void testBuildTransformWithoutProtoSchemaFormat() { + ServiceLoader serviceLoader = + ServiceLoader.load(SchemaTransformProvider.class); + List providers = + StreamSupport.stream(serviceLoader.spliterator(), false) + .filter(provider -> provider.getClass() == KafkaReadSchemaTransformProvider.class) + .collect(Collectors.toList()); + KafkaReadSchemaTransformProvider kafkaProvider = + (KafkaReadSchemaTransformProvider) providers.get(0); + + assertThrows( + NullPointerException.class, + () -> + kafkaProvider.from( + KafkaReadSchemaTransformConfiguration.builder() + .setTopic("anytopic") + .setBootstrapServers("anybootstrap") + .setFormat("PROTO") + .setMessageName("MyMessage") + .build())); + } } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java index 20f474790cc7..48d463a8f436 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java @@ -27,6 +27,7 @@ import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils; import org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform.ErrorCounterFn; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling; import org.apache.beam.sdk.schemas.utils.JsonUtils; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; @@ -56,7 +57,6 @@ public class KafkaWriteSchemaTransformProviderTest { private static final Schema BEAM_RAW_SCHEMA = Schema.of(Schema.Field.of("payload", Schema.FieldType.BYTES)); - private static final Schema ERRORSCHEMA = KafkaWriteSchemaTransformProvider.ERROR_SCHEMA; private static final Schema BEAM_PROTO_SCHEMA = Schema.builder() @@ -135,12 +135,14 @@ public void testKafkaErrorFnSuccess() throws Exception { KV.of(new byte[1], "{\"name\":\"c\"}".getBytes("UTF8"))); PCollection input = p.apply(Create.of(ROWS)); + Schema errorSchema = ErrorHandling.errorSchema(BEAMSCHEMA); PCollectionTuple output = input.apply( - ParDo.of(new ErrorCounterFn("Kafka-write-error-counter", valueMapper)) + ParDo.of( + new ErrorCounterFn("Kafka-write-error-counter", valueMapper, errorSchema, true)) .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); - output.get(ERROR_TAG).setRowSchema(ERRORSCHEMA); + output.get(ERROR_TAG).setRowSchema(errorSchema); PAssert.that(output.get(OUTPUT_TAG)).containsInAnyOrder(msg); p.run().waitUntilFinish(); @@ -155,12 +157,15 @@ public void testKafkaErrorFnRawSuccess() throws Exception { KV.of(new byte[1], "c".getBytes("UTF8"))); PCollection input = p.apply(Create.of(RAW_ROWS)); + Schema errorSchema = ErrorHandling.errorSchema(BEAM_RAW_SCHEMA); PCollectionTuple output = input.apply( - ParDo.of(new ErrorCounterFn("Kafka-write-error-counter", valueRawMapper)) + ParDo.of( + new ErrorCounterFn( + "Kafka-write-error-counter", valueRawMapper, errorSchema, true)) .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); - output.get(ERROR_TAG).setRowSchema(ERRORSCHEMA); + output.get(ERROR_TAG).setRowSchema(errorSchema); PAssert.that(output.get(OUTPUT_TAG)).containsInAnyOrder(msg); 
p.run().waitUntilFinish(); @@ -169,12 +174,15 @@ public void testKafkaErrorFnRawSuccess() throws Exception { @Test public void testKafkaErrorFnProtoSuccess() { PCollection input = p.apply(Create.of(PROTO_ROWS)); + Schema errorSchema = ErrorHandling.errorSchema(BEAM_PROTO_SCHEMA); PCollectionTuple output = input.apply( - ParDo.of(new ErrorCounterFn("Kafka-write-error-counter", protoValueRawMapper)) + ParDo.of( + new ErrorCounterFn( + "Kafka-write-error-counter", protoValueRawMapper, errorSchema, true)) .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); - PAssert.that(output.get(ERROR_TAG).setRowSchema(ERRORSCHEMA)).empty(); + output.get(ERROR_TAG).setRowSchema(errorSchema); p.run().waitUntilFinish(); } } diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml index b49d40d5a40b..b617a4cbf285 100644 --- a/sdks/python/apache_beam/yaml/standard_io.yaml +++ b/sdks/python/apache_beam/yaml/standard_io.yaml @@ -71,6 +71,7 @@ 'producer_config_updates': 'producerConfigUpdates' 'file_descriptor_path': 'fileDescriptorPath' 'message_name': 'messageName' + 'schema': 'schema' underlying_provider: type: beamJar transforms:
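Reviewer note: downstream, the read transform can now be configured with an inline proto schema instead of a file descriptor set. A minimal sketch of that wiring (illustrative topic/bootstrap values; placed in the provider's package, as the tests are, so the protected from(...) is reachable):

```java
package org.apache.beam.sdk.io.kafka; // same package as the provider, like the tests above

import org.apache.beam.sdk.schemas.transforms.SchemaTransform;

class KafkaProtoReadConfigExample {
  static SchemaTransform buildProtoRead(String protoSchemaText) {
    KafkaReadSchemaTransformConfiguration config =
        KafkaReadSchemaTransformConfiguration.builder()
            .setTopic("my-topic") // illustrative values
            .setBootstrapServers("localhost:9092")
            .setFormat("PROTO")
            .setMessageName("MyMessage")
            .setSchema(protoSchemaText) // inline .proto text instead of setFileDescriptorPath
            .build();

    // validate() now enforces the PROTO rules: messageName is required, and either
    // fileDescriptorPath or schema must be present.
    config.validate();

    return new KafkaReadSchemaTransformProvider().from(config);
  }
}
```

The added 'schema' mapping in standard_io.yaml passes the same option through to Beam YAML pipelines.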
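On the write side, a comparable sketch using the new schema and errorHandling fields (illustrative values; setTopic and setFormat are assumed from the existing configuration class since they are not shown in this diff, and ErrorHandling.builder() is assumed from the providers package):

```java
package org.apache.beam.sdk.io.kafka; // keeps the nested configuration class easily reachable

import org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransformConfiguration;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;

class KafkaProtoWriteConfigExample {
  static KafkaWriteSchemaTransformConfiguration buildProtoWrite(String protoSchemaText) {
    return KafkaWriteSchemaTransformConfiguration.builder()
        .setTopic("my-topic") // illustrative values
        .setBootstrapServers("localhost:9092")
        .setFormat("PROTO")
        .setMessageName("MyMessage")
        .setSchema(protoSchemaText) // mutually exclusive with setFileDescriptorPath
        // With errorHandling set, rows that fail serialization are routed to an output
        // named "my-errors" instead of failing the bundle.
        .setErrorHandling(ErrorHandling.builder().setOutput("my-errors").build())
        .build();
  }
}
```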