diff --git a/sdks/java/extensions/protobuf/build.gradle b/sdks/java/extensions/protobuf/build.gradle index 2696f8886ddd..568d4f220867 100644 --- a/sdks/java/extensions/protobuf/build.gradle +++ b/sdks/java/extensions/protobuf/build.gradle @@ -35,6 +35,8 @@ ext.summary = "Add support to Apache Beam for Google Protobuf." dependencies { implementation library.java.byte_buddy implementation library.java.vendored_guava_32_1_2_jre + implementation library.java.commons_compress + implementation library.java.slf4j_api implementation project(path: ":sdks:java:core", configuration: "shadow") implementation library.java.protobuf_java testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtils.java new file mode 100644 index 000000000000..f156fed0f38c --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtils.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static java.util.stream.Collectors.toList; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; + +import com.google.protobuf.DescriptorProtos; +import com.google.protobuf.Descriptors; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.InvalidProtocolBufferException; +import java.io.IOException; +import java.io.InputStream; +import java.io.Serializable; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.util.List; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.fs.MatchResult; +import org.apache.beam.sdk.io.fs.ResourceId; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.commons.compress.utils.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Utility class for working with Protocol Buffer (Proto) data in the context of Apache Beam. This + * class provides methods to retrieve Beam Schemas from Proto messages, convert Proto bytes to Beam + * Rows, and vice versa. It also includes utilities for handling Protocol Buffer schemas and related + * file operations. + * + *

Users can utilize the methods in this class to facilitate the integration of Proto data + * processing within Apache Beam pipelines, allowing for the seamless transformation of Proto + * messages to Beam Rows and vice versa. + */ +public class ProtoByteUtils { + + private static final Logger LOG = LoggerFactory.getLogger(ProtoByteUtils.class); + + /** + * Retrieves a Beam Schema from a Protocol Buffer message. + * + * @param fileDescriptorPath The path to the File Descriptor Set file. + * @param messageName The name of the Protocol Buffer message. + * @return The Beam Schema representing the Protocol Buffer message. + */ + public static Schema getBeamSchemaFromProto(String fileDescriptorPath, String messageName) { + ProtoSchemaInfo dpd = getProtoDomain(fileDescriptorPath); + ProtoDomain protoDomain = dpd.getProtoDomain(); + return ProtoDynamicMessageSchema.forDescriptor(protoDomain, messageName).getSchema(); + } + + public static SerializableFunction getProtoBytesToRowFunction( + String fileDescriptorPath, String messageName) { + + ProtoSchemaInfo dynamicProtoDomain = getProtoDomain(fileDescriptorPath); + ProtoDomain protoDomain = dynamicProtoDomain.getProtoDomain(); + @SuppressWarnings("unchecked") + ProtoDynamicMessageSchema protoDynamicMessageSchema = + ProtoDynamicMessageSchema.forDescriptor(protoDomain, messageName); + return new SimpleFunction() { + @Override + public Row apply(byte[] input) { + try { + final Descriptors.Descriptor descriptor = + protoDomain + .getFileDescriptor(dynamicProtoDomain.getFileName()) + .findMessageTypeByName(messageName); + DynamicMessage dynamicMessage = DynamicMessage.parseFrom(descriptor, input); + SerializableFunction res = + protoDynamicMessageSchema.getToRowFunction(); + return res.apply(dynamicMessage); + } catch (InvalidProtocolBufferException e) { + LOG.error("Error parsing to DynamicMessage", e); + throw new RuntimeException(e); + } + } + }; + } + + public static SerializableFunction getRowToProtoBytes( + String fileDescriptorPath, String messageName) { + ProtoSchemaInfo dynamicProtoDomain = getProtoDomain(fileDescriptorPath); + ProtoDomain protoDomain = dynamicProtoDomain.getProtoDomain(); + @SuppressWarnings("unchecked") + ProtoDynamicMessageSchema protoDynamicMessageSchema = + ProtoDynamicMessageSchema.forDescriptor(protoDomain, messageName); + + return new SimpleFunction() { + @Override + public byte[] apply(Row input) { + SerializableFunction res = + protoDynamicMessageSchema.getFromRowFunction(); + return res.apply(input).toByteArray(); + } + }; + } + + /** + * Retrieves a ProtoSchemaInfo containing schema information for the specified Protocol Buffer + * file. + * + * @param fileDescriptorPath The path to the File Descriptor Set file. + * @return ProtoSchemaInfo containing the associated ProtoDomain and File Name. + * @throws RuntimeException if an error occurs during schema retrieval. + */ + private static ProtoSchemaInfo getProtoDomain(String fileDescriptorPath) { + byte[] from = getFileAsBytes(fileDescriptorPath); + try { + DescriptorProtos.FileDescriptorSet descriptorSet = + DescriptorProtos.FileDescriptorSet.parseFrom(from); + return new ProtoSchemaInfo( + descriptorSet.getFile(0).getName(), ProtoDomain.buildFrom(descriptorSet)); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + } + + /** + * Reads the contents of a file specified by its path and returns them as a byte array. + * + * @param fileDescriptorPath The path to the file to read. + * @return Byte array containing the file contents. 
+ * @throws RuntimeException if an error occurs during file reading. + */ + private static byte[] getFileAsBytes(String fileDescriptorPath) { + ReadableByteChannel channel = getFileByteChannel(fileDescriptorPath); + try (InputStream inputStream = Channels.newInputStream(channel)) { + return IOUtils.toByteArray(inputStream); + } catch (IOException e) { + throw new RuntimeException("Error when reading: " + fileDescriptorPath, e); + } + } + + /** + * Retrieves a ReadableByteChannel for a file specified by its path. + * + * @param filePath The path to the file to obtain a ReadableByteChannel for. + * @return ReadableByteChannel for the specified file. + * @throws RuntimeException if an error occurs while finding or opening the file. + */ + private static ReadableByteChannel getFileByteChannel(String filePath) { + try { + MatchResult result = FileSystems.match(filePath); + checkArgument( + result.status() == MatchResult.Status.OK && !result.metadata().isEmpty(), + "Failed to match any files with the pattern: " + filePath); + + List rId = + result.metadata().stream().map(MatchResult.Metadata::resourceId).collect(toList()); + + checkArgument(rId.size() == 1, "Expected exactly 1 file, but got " + rId.size() + " files."); + return FileSystems.open(rId.get(0)); + } catch (IOException e) { + throw new RuntimeException("Error when finding: " + filePath, e); + } + } + + /** + * Represents metadata associated with a Protocol Buffer schema, including the File Name and + * ProtoDomain. + */ + static class ProtoSchemaInfo implements Serializable { + private String fileName; + private ProtoDomain protoDomain; + + /** + * Constructs a ProtoSchemaInfo with the specified File Name and ProtoDomain. + * + * @param fileName The name of the associated Protocol Buffer file. + * @param protoDomain The ProtoDomain containing schema information. + */ + public ProtoSchemaInfo(String fileName, ProtoDomain protoDomain) { + this.fileName = fileName; + this.protoDomain = protoDomain; + } + + /** + * Sets the ProtoDomain associated with this ProtoSchemaInfo. + * + * @param protoDomain The ProtoDomain to set. + */ + @SuppressWarnings("unused") + public void setProtoDomain(ProtoDomain protoDomain) { + this.protoDomain = protoDomain; + } + + /** + * Gets the ProtoDomain associated with this ProtoSchemaInfo. + * + * @return The ProtoDomain containing schema information. + */ + public ProtoDomain getProtoDomain() { + return protoDomain; + } + + /** + * Gets the File Name associated with this ProtoSchemaInfo. + * + * @return The name of the associated Protocol Buffer file. + */ + public String getFileName() { + return fileName; + } + + /** + * Sets the File Name associated with this ProtoSchemaInfo. + * + * @param fileName The name of the Protocol Buffer file to set. + */ + public void setFileName(String fileName) { + this.fileName = fileName; + } + } +} diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtilsTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtilsTest.java new file mode 100644 index 000000000000..2a4cb4b5d5fb --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteUtilsTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import java.util.Objects; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ProtoByteUtilsTest { + + private static final String DESCRIPTOR_PATH = + Objects.requireNonNull( + ProtoByteUtilsTest.class.getResource( + "/proto_byte/file_descriptor/proto_byte_utils.pb")) + .getPath(); + + private static final String MESSAGE_NAME = "MyMessage"; + + private static final Schema SCHEMA = + Schema.builder() + .addField("id", Schema.FieldType.INT32) + .addField("name", Schema.FieldType.STRING) + .addField("active", Schema.FieldType.BOOLEAN) + .addField( + "address", + Schema.FieldType.row( + Schema.builder() + .addField("city", Schema.FieldType.STRING) + .addField("street", Schema.FieldType.STRING) + .addField("state", Schema.FieldType.STRING) + .addField("zip_code", Schema.FieldType.STRING) + .build())) + .build(); + + @Test + public void testProtoSchemaToBeamSchema() { + Schema schema = ProtoByteUtils.getBeamSchemaFromProto(DESCRIPTOR_PATH, MESSAGE_NAME); + Assert.assertEquals(schema.getFieldNames(), SCHEMA.getFieldNames()); + } + + @Test + public void testProtoBytesToRowFunctionGenerateSerializableFunction() { + SerializableFunction protoBytesToRowFunction = + ProtoByteUtils.getProtoBytesToRowFunction(DESCRIPTOR_PATH, MESSAGE_NAME); + Assert.assertNotNull(protoBytesToRowFunction); + } + + @Test(expected = java.lang.RuntimeException.class) + public void testProtoBytesToRowFunctionReturnsRowFailure() { + // Create a proto bytes to row function + SerializableFunction protoBytesToRowFunction = + ProtoByteUtils.getProtoBytesToRowFunction(DESCRIPTOR_PATH, MESSAGE_NAME); + + // Create some test input bytes that are not matching + byte[] inputBytes = new byte[] {1, 2, 3, 4, 5}; + + // Call the proto bytes to row function that should fail because the input does not match + protoBytesToRowFunction.apply(inputBytes); + } + + @Test + public void testRowToProtoFunction() { + Row row = + Row.withSchema(SCHEMA) + .withFieldValue("id", 1234) + .withFieldValue("name", "Doe") + .withFieldValue("active", false) + .withFieldValue("address.city", "seattle") + .withFieldValue("address.street", "fake street") + .withFieldValue("address.zip_code", "TO-1234") + .withFieldValue("address.state", "wa") + .build(); + + Assert.assertNotNull( + ProtoByteUtils.getRowToProtoBytes(DESCRIPTOR_PATH, MESSAGE_NAME).apply(row)); + } +} diff --git a/sdks/java/extensions/protobuf/src/test/resources/README.md b/sdks/java/extensions/protobuf/src/test/resources/README.md index 79083f5142b0..de9cb742788b 100644 --- a/sdks/java/extensions/protobuf/src/test/resources/README.md +++ b/sdks/java/extensions/protobuf/src/test/resources/README.md @@ -32,3 +32,9 @@ 
protoc \ --include_imports \ sdks/java/extensions/protobuf/src/test/resources/test/option/v1/simple.proto ``` +```bash +protoc \ + -Isdks/java/extensions/protobuf/src/test/resources/ \ + --descriptor_set_out=sdks/java/extensions/protobuf/src/test/resources/proto_byte/file_descriptor/proto_byte_utils.pb \ + sdks/java/extensions/protobuf/src/test/resources/proto_byte/proto_byte_utils.proto +``` diff --git a/sdks/java/extensions/protobuf/src/test/resources/proto_byte/file_descriptor/proto_byte_utils.pb b/sdks/java/extensions/protobuf/src/test/resources/proto_byte/file_descriptor/proto_byte_utils.pb new file mode 100644 index 000000000000..67e93cc177cc --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/resources/proto_byte/file_descriptor/proto_byte_utils.pb @@ -0,0 +1,13 @@ + +ú +test_proto.proto"Ý + MyMessage +id (Rid +name ( Rname +active (Ractive, +address ( 2.MyMessage.AddressRaddressf +Address +street ( Rstreet +city ( Rcity +state ( Rstate +zip_code ( RzipCodebproto3 \ No newline at end of file diff --git a/sdks/java/extensions/protobuf/src/test/resources/proto_byte/proto_byte_utils.proto b/sdks/java/extensions/protobuf/src/test/resources/proto_byte/proto_byte_utils.proto new file mode 100644 index 000000000000..aead141f4b9a --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/resources/proto_byte/proto_byte_utils.proto @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +message MyMessage { + int32 id = 1; + string name = 2; + bool active = 3; + + // Nested field + message Address { + string street = 1; + string city = 2; + string state = 3; + string zip_code = 4; + } + + Address address = 4; +} diff --git a/sdks/java/io/kafka/README.md b/sdks/java/io/kafka/README.md index 4ecf095bec5b..b137e0b240a9 100644 --- a/sdks/java/io/kafka/README.md +++ b/sdks/java/io/kafka/README.md @@ -47,3 +47,13 @@ complete list. The documentation is maintained in JavaDoc for KafkaIO class. It includes usage examples and primary concepts. - [KafkaIO.java](src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java) + +### Protobuf tests +This recreates the proto descriptor set included in this resource directory. 
+ +```bash +protoc \ + -Isdks/java/io/kafka/src/test/resources/ \ + --descriptor_set_out=sdks/java/io/kafka/src/test/resources/proto_byte/file_descriptor/proto_byte_utils.pb \ + sdks/java/io/kafka/src/test/resources/proto_byte/proto_byte_utils.proto +``` \ No newline at end of file diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle index 61209aa50928..dc190ef9d8fd 100644 --- a/sdks/java/io/kafka/build.gradle +++ b/sdks/java/io/kafka/build.gradle @@ -51,6 +51,7 @@ dependencies { permitUnusedDeclared library.java.jackson_dataformat_csv implementation project(path: ":sdks:java:core", configuration: "shadow") implementation project(":sdks:java:extensions:avro") + implementation project(":sdks:java:extensions:protobuf") implementation project(":runners:core-construction-java") implementation project(":sdks:java:expansion-service") permitUnusedDeclared project(":sdks:java:expansion-service") // BEAM-11761 diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java index ae03c49b9b04..2fa365b1c7f3 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java @@ -40,7 +40,7 @@ public abstract class KafkaReadSchemaTransformConfiguration { public static final Set VALID_START_OFFSET_VALUES = Sets.newHashSet("earliest", "latest"); - public static final String VALID_FORMATS_STR = "RAW,AVRO,JSON"; + public static final String VALID_FORMATS_STR = "RAW,AVRO,JSON,PROTO"; public static final Set VALID_DATA_FORMATS = Sets.newHashSet(VALID_FORMATS_STR.split(",")); @@ -87,6 +87,18 @@ public static Builder builder() { @Nullable public abstract String getSchema(); + @SchemaFieldDescription( + "The path to the Protocol Buffer File Descriptor Set file. This file is used for schema" + + " definition and message serialization.") + @Nullable + public abstract String getFileDescriptorPath(); + + @SchemaFieldDescription( + "The name of the Protocol Buffer message to be used for schema" + + " extraction and data conversion.") + @Nullable + public abstract String getMessageName(); + @SchemaFieldDescription( "What to do when there is no initial offset in Kafka or if the current offset" + " does not exist any more on the server. 
(1) earliest: automatically reset the offset to the earliest" @@ -123,6 +135,10 @@ public abstract static class Builder { public abstract Builder setSchema(String schema); + public abstract Builder setFileDescriptorPath(String fileDescriptorPath); + + public abstract Builder setMessageName(String messageName); + public abstract Builder setFormat(String format); public abstract Builder setAutoOffsetResetConfig(String startOffset); diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java index 2f5278cd7129..996976ee9a75 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.kafka; +import static org.apache.beam.sdk.io.kafka.KafkaReadSchemaTransformConfiguration.VALID_DATA_FORMATS; + import com.google.auto.service.AutoService; import java.io.FileOutputStream; import java.io.IOException; @@ -35,6 +37,7 @@ import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; @@ -95,9 +98,6 @@ protected Class configurationClass() { return KafkaReadSchemaTransformConfiguration.class; } - @SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) - }) @Override protected SchemaTransform from(KafkaReadSchemaTransformConfiguration configuration) { final String inputSchema = configuration.getSchema(); @@ -115,8 +115,14 @@ protected SchemaTransform from(KafkaReadSchemaTransformConfiguration configurati String format = configuration.getFormat(); boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling()); - - if ((format != null && format.equals("RAW")) || (!Strings.isNullOrEmpty(inputSchema))) { + String descriptorPath = configuration.getFileDescriptorPath(); + String messageName = configuration.getMessageName(); + + if ((format != null && VALID_DATA_FORMATS.contains(format)) + || (!Strings.isNullOrEmpty(inputSchema) && !Objects.equals(format, "RAW")) + || (Objects.equals(format, "PROTO") + && !Strings.isNullOrEmpty(descriptorPath) + && !Strings.isNullOrEmpty(messageName))) { SerializableFunction valueMapper; Schema beamSchema; if (format != null && format.equals("RAW")) { @@ -126,11 +132,21 @@ protected SchemaTransform from(KafkaReadSchemaTransformConfiguration configurati } beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build(); valueMapper = getRawBytesToRowFunction(beamSchema); + } else if (format != null && format.equals("PROTO")) { + if (descriptorPath == null || messageName == null) { + throw new IllegalArgumentException( + "Expecting both descriptorPath and messageName to be non-null."); + } + valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(descriptorPath, messageName); + beamSchema = ProtoByteUtils.getBeamSchemaFromProto(descriptorPath, messageName); } else { assert Strings.isNullOrEmpty(configuration.getConfluentSchemaRegistryUrl()) : "To read from Kafka, a schema must be provided directly or though Confluent " + "Schema Registry, but not both."; - 
+ if (inputSchema == null) { + throw new IllegalArgumentException( + "To read from Kafka in JSON or AVRO format, you must provide a schema."); + } beamSchema = Objects.equals(format, "JSON") ? JsonUtils.beamSchemaFromJsonSchema(inputSchema) @@ -170,7 +186,11 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { PCollection errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema); if (handleErrors) { - outputRows = outputRows.and(configuration.getErrorHandling().getOutput(), errorOutput); + ErrorHandling errorHandling = configuration.getErrorHandling(); + if (errorHandling == null) { + throw new IllegalArgumentException("You must specify an error handling option."); + } + outputRows = outputRows.and(errorHandling.getOutput(), errorOutput); } return outputRows; } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java index 93d6d73493d2..694c3e9f2c14 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java @@ -27,6 +27,7 @@ import java.util.Set; import javax.annotation.Nullable; import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.schemas.AutoValueSchema; @@ -36,14 +37,13 @@ import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; -import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling; import org.apache.beam.sdk.schemas.utils.JsonUtils; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.ProcessElement; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.Row; @@ -62,12 +62,14 @@ public class KafkaWriteSchemaTransformProvider extends TypedSchemaTransformProvider< KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransformConfiguration> { - public static final String SUPPORTED_FORMATS_STR = "RAW,JSON,AVRO"; + public static final String SUPPORTED_FORMATS_STR = "RAW,JSON,AVRO,PROTO"; public static final Set SUPPORTED_FORMATS = Sets.newHashSet(SUPPORTED_FORMATS_STR.split(",")); public static final TupleTag ERROR_TAG = new TupleTag() {}; public static final TupleTag> OUTPUT_TAG = new TupleTag>() {}; + public static final Schema ERROR_SCHEMA = + Schema.builder().addStringField("error").addNullableByteArrayField("row").build(); private static final Logger LOG = LoggerFactory.getLogger(KafkaWriteSchemaTransformProvider.class); @@ -99,38 +101,25 @@ static final class KafkaWriteSchemaTransform extends SchemaTransform implements } public static class ErrorCounterFn extends DoFn> { - private final SerializableFunction toBytesFn; - private final Counter errorCounter; + private SerializableFunction toBytesFn; + private Counter errorCounter; 
private Long errorsInBundle = 0L; - private final boolean handleErrors; - private final Schema errorSchema; - - public ErrorCounterFn( - String name, - SerializableFunction toBytesFn, - Schema errorSchema, - boolean handleErrors) { + + public ErrorCounterFn(String name, SerializableFunction toBytesFn) { this.toBytesFn = toBytesFn; - this.errorCounter = Metrics.counter(KafkaWriteSchemaTransformProvider.class, name); - this.handleErrors = handleErrors; - this.errorSchema = errorSchema; + errorCounter = Metrics.counter(KafkaWriteSchemaTransformProvider.class, name); } @ProcessElement public void process(@DoFn.Element Row row, MultiOutputReceiver receiver) { - KV output = null; try { - output = KV.of(new byte[1], toBytesFn.apply(row)); + receiver.get(OUTPUT_TAG).output(KV.of(new byte[1], toBytesFn.apply(row))); } catch (Exception e) { - if (!handleErrors) { - throw new RuntimeException(e); - } errorsInBundle += 1; LOG.warn("Error while processing the element", e); - receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, row, e)); - } - if (output != null) { - receiver.get(OUTPUT_TAG).output(output); + receiver + .get(ERROR_TAG) + .output(Row.withSchema(ERROR_SCHEMA).addValues(e.toString(), row.toString()).build()); } } @@ -141,9 +130,6 @@ public void finish() { } } - @SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) - }) @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { Schema inputSchema = input.get("input").getSchema(); @@ -153,24 +139,32 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { if (numFields != 1) { throw new IllegalArgumentException("Expecting exactly one field, found " + numFields); } + if (!inputSchema.getField(0).getType().equals(Schema.FieldType.BYTES)) { + throw new IllegalArgumentException( + "The input schema must have exactly one field of type bytes."); + } toBytesFn = getRowToRawBytesFunction(inputSchema.getField(0).getName()); } else if (configuration.getFormat().equals("JSON")) { toBytesFn = JsonUtils.getRowToJsonBytesFunction(inputSchema); + } else if (configuration.getFormat().equals("PROTO")) { + String descriptorPath = configuration.getFileDescriptorPath(); + String messageName = configuration.getMessageName(); + if (descriptorPath == null || messageName == null) { + throw new IllegalArgumentException( + "Expecting both descriptorPath and messageName to be non-null."); + } + toBytesFn = ProtoByteUtils.getRowToProtoBytes(descriptorPath, messageName); } else { toBytesFn = AvroUtils.getRowToAvroBytesFunction(inputSchema); } - boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling()); final Map configOverrides = configuration.getProducerConfigUpdates(); - Schema errorSchema = ErrorHandling.errorSchema(inputSchema); PCollectionTuple outputTuple = input .get("input") .apply( "Map rows to Kafka messages", - ParDo.of( - new ErrorCounterFn( - "Kafka-write-error-counter", toBytesFn, errorSchema, handleErrors)) + ParDo.of(new ErrorCounterFn("Kafka-write-error-counter", toBytesFn)) .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); outputTuple @@ -186,11 +180,8 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { .withKeySerializer(ByteArraySerializer.class) .withValueSerializer(ByteArraySerializer.class)); - // TODO: include output from KafkaIO Write once updated from PDone - PCollection errorOutput = - outputTuple.get(ERROR_TAG).setRowSchema(ErrorHandling.errorSchema(errorSchema)); return PCollectionRowTuple.of( - handleErrors ?
configuration.getErrorHandling().getOutput() : "errors", errorOutput); + "errors", outputTuple.get(ERROR_TAG).setRowSchema(ERROR_SCHEMA)); } } @@ -241,6 +232,18 @@ public abstract static class KafkaWriteSchemaTransformConfiguration implements S + " of servers. | Format: host1:port1,host2:port2,...") public abstract String getBootstrapServers(); + @SchemaFieldDescription( + "The path to the Protocol Buffer File Descriptor Set file. This file is used for schema" + + " definition and message serialization.") + @Nullable + public abstract String getFileDescriptorPath(); + + @SchemaFieldDescription( + "The name of the Protocol Buffer message to be used for schema" + + " extraction and data conversion.") + @Nullable + public abstract String getMessageName(); + @SchemaFieldDescription( "A list of key-value pairs that act as configuration parameters for Kafka producers." + " Most of these configurations will not be needed, but if you need to customize your Kafka producer," @@ -249,10 +252,6 @@ public abstract static class KafkaWriteSchemaTransformConfiguration implements S @Nullable public abstract Map getProducerConfigUpdates(); - @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.") - @Nullable - public abstract ErrorHandling getErrorHandling(); - public static Builder builder() { return new AutoValue_KafkaWriteSchemaTransformProvider_KafkaWriteSchemaTransformConfiguration .Builder(); @@ -266,9 +265,11 @@ public abstract static class Builder { public abstract Builder setBootstrapServers(String bootstrapServers); - public abstract Builder setProducerConfigUpdates(Map producerConfigUpdates); + public abstract Builder setFileDescriptorPath(String fileDescriptorPath); + + public abstract Builder setMessageName(String messageName); - public abstract Builder setErrorHandling(ErrorHandling errorHandling); + public abstract Builder setProducerConfigUpdates(Map producerConfigUpdates); public abstract KafkaWriteSchemaTransformConfiguration build(); } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java index 1570a33c2580..27fa18715c32 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java @@ -101,7 +101,9 @@ public void testFindTransformAndMakeItWork() { "format", "confluentSchemaRegistrySubject", "confluentSchemaRegistryUrl", - "errorHandling"), + "errorHandling", + "fileDescriptorPath", + "messageName"), kafkaProvider.configurationSchema().getFields().stream() .map(field -> field.getName()) .collect(Collectors.toSet())); @@ -150,7 +152,7 @@ public void testBuildTransformWithJsonSchema() throws IOException { } @Test - public void testBuildTransformWithRawFormat() throws IOException { + public void testBuildTransformWithRawFormat() { ServiceLoader serviceLoader = ServiceLoader.load(SchemaTransformProvider.class); List providers = @@ -166,4 +168,56 @@ public void testBuildTransformWithRawFormat() throws IOException { .setFormat("RAW") .build()); } + + @Test + public void testBuildTransformWithProtoFormat() { + ServiceLoader serviceLoader = + ServiceLoader.load(SchemaTransformProvider.class); + List providers = + StreamSupport.stream(serviceLoader.spliterator(), false) + .filter(provider -> provider.getClass() == 
KafkaReadSchemaTransformProvider.class) + .collect(Collectors.toList()); + KafkaReadSchemaTransformProvider kafkaProvider = + (KafkaReadSchemaTransformProvider) providers.get(0); + + kafkaProvider.from( + KafkaReadSchemaTransformConfiguration.builder() + .setTopic("anytopic") + .setBootstrapServers("anybootstrap") + .setFormat("PROTO") + .setMessageName("MyMessage") + .setFileDescriptorPath( + Objects.requireNonNull( + getClass().getResource("/proto_byte/file_descriptor/proto_byte_utils.pb")) + .getPath()) + .build()); + } + + @Test + public void testBuildTransformWithProtoFormatWrongMessageName() { + ServiceLoader serviceLoader = + ServiceLoader.load(SchemaTransformProvider.class); + List providers = + StreamSupport.stream(serviceLoader.spliterator(), false) + .filter(provider -> provider.getClass() == KafkaReadSchemaTransformProvider.class) + .collect(Collectors.toList()); + KafkaReadSchemaTransformProvider kafkaProvider = + (KafkaReadSchemaTransformProvider) providers.get(0); + + assertThrows( + NullPointerException.class, + () -> + kafkaProvider.from( + KafkaReadSchemaTransformConfiguration.builder() + .setTopic("anytopic") + .setBootstrapServers("anybootstrap") + .setFormat("PROTO") + .setMessageName("MyOtherMessage") + .setFileDescriptorPath( + Objects.requireNonNull( + getClass() + .getResource("/proto_byte/file_descriptor/proto_byte_utils.pb")) + .getPath()) + .build())); + } } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java index b8649727f36d..20f474790cc7 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java @@ -21,10 +21,12 @@ import java.io.UnsupportedEncodingException; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Objects; +import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils; import org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform.ErrorCounterFn; import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling; import org.apache.beam.sdk.schemas.utils.JsonUtils; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; @@ -52,8 +54,37 @@ public class KafkaWriteSchemaTransformProviderTest { private static final Schema BEAMSCHEMA = Schema.of(Schema.Field.of("name", Schema.FieldType.STRING)); - private static final Schema BEAMRAWSCHEMA = + private static final Schema BEAM_RAW_SCHEMA = Schema.of(Schema.Field.of("payload", Schema.FieldType.BYTES)); + private static final Schema ERRORSCHEMA = KafkaWriteSchemaTransformProvider.ERROR_SCHEMA; + + private static final Schema BEAM_PROTO_SCHEMA = + Schema.builder() + .addField("id", Schema.FieldType.INT32) + .addField("name", Schema.FieldType.STRING) + .addField("active", Schema.FieldType.BOOLEAN) + .addField( + "address", + Schema.FieldType.row( + Schema.builder() + .addField("city", Schema.FieldType.STRING) + .addField("street", Schema.FieldType.STRING) + .addField("state", Schema.FieldType.STRING) + .addField("zip_code", Schema.FieldType.STRING) + .build())) + .build(); + + private static final List PROTO_ROWS = + Collections.singletonList( + Row.withSchema(BEAM_PROTO_SCHEMA) + .withFieldValue("id", 1234) + 
.withFieldValue("name", "Doe") + .withFieldValue("active", false) + .withFieldValue("address.city", "seattle") + .withFieldValue("address.street", "fake street") + .withFieldValue("address.zip_code", "TO-1234") + .withFieldValue("address.state", "wa") + .build()); private static final List ROWS = Arrays.asList( @@ -67,9 +98,13 @@ public class KafkaWriteSchemaTransformProviderTest { try { RAW_ROWS = Arrays.asList( - Row.withSchema(BEAMRAWSCHEMA).withFieldValue("payload", "a".getBytes("UTF8")).build(), - Row.withSchema(BEAMRAWSCHEMA).withFieldValue("payload", "b".getBytes("UTF8")).build(), - Row.withSchema(BEAMRAWSCHEMA) + Row.withSchema(BEAM_RAW_SCHEMA) + .withFieldValue("payload", "a".getBytes("UTF8")) + .build(), + Row.withSchema(BEAM_RAW_SCHEMA) + .withFieldValue("payload", "b".getBytes("UTF8")) + .build(), + Row.withSchema(BEAM_RAW_SCHEMA) .withFieldValue("payload", "c".getBytes("UTF8")) .build()); } catch (UnsupportedEncodingException e) { @@ -82,6 +117,13 @@ public class KafkaWriteSchemaTransformProviderTest { final SerializableFunction valueRawMapper = getRowToRawBytesFunction("payload"); + final SerializableFunction protoValueRawMapper = + ProtoByteUtils.getRowToProtoBytes( + Objects.requireNonNull( + getClass().getResource("/proto_byte/file_descriptor/proto_byte_utils.pb")) + .getPath(), + "MyMessage"); + @Rule public transient TestPipeline p = TestPipeline.create(); @Test @@ -93,14 +135,12 @@ public void testKafkaErrorFnSuccess() throws Exception { KV.of(new byte[1], "{\"name\":\"c\"}".getBytes("UTF8"))); PCollection input = p.apply(Create.of(ROWS)); - Schema errorSchema = ErrorHandling.errorSchema(BEAMSCHEMA); PCollectionTuple output = input.apply( - ParDo.of( - new ErrorCounterFn("Kafka-write-error-counter", valueMapper, errorSchema, true)) + ParDo.of(new ErrorCounterFn("Kafka-write-error-counter", valueMapper)) .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); - output.get(ERROR_TAG).setRowSchema(errorSchema); + output.get(ERROR_TAG).setRowSchema(ERRORSCHEMA); PAssert.that(output.get(OUTPUT_TAG)).containsInAnyOrder(msg); p.run().waitUntilFinish(); @@ -115,17 +155,26 @@ public void testKafkaErrorFnRawSuccess() throws Exception { KV.of(new byte[1], "c".getBytes("UTF8"))); PCollection input = p.apply(Create.of(RAW_ROWS)); - Schema errorSchema = ErrorHandling.errorSchema(BEAMRAWSCHEMA); PCollectionTuple output = input.apply( - ParDo.of( - new ErrorCounterFn( - "Kafka-write-error-counter", valueRawMapper, errorSchema, true)) + ParDo.of(new ErrorCounterFn("Kafka-write-error-counter", valueRawMapper)) .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); - output.get(ERROR_TAG).setRowSchema(errorSchema); + output.get(ERROR_TAG).setRowSchema(ERRORSCHEMA); PAssert.that(output.get(OUTPUT_TAG)).containsInAnyOrder(msg); p.run().waitUntilFinish(); } + + @Test + public void testKafkaErrorFnProtoSuccess() { + PCollection input = p.apply(Create.of(PROTO_ROWS)); + PCollectionTuple output = + input.apply( + ParDo.of(new ErrorCounterFn("Kafka-write-error-counter", protoValueRawMapper)) + .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); + + PAssert.that(output.get(ERROR_TAG).setRowSchema(ERRORSCHEMA)).empty(); + p.run().waitUntilFinish(); + } } diff --git a/sdks/java/io/kafka/src/test/resources/proto_byte/file_descriptor/proto_byte_utils.pb b/sdks/java/io/kafka/src/test/resources/proto_byte/file_descriptor/proto_byte_utils.pb new file mode 100644 index 000000000000..67e93cc177cc --- /dev/null +++ 
b/sdks/java/io/kafka/src/test/resources/proto_byte/file_descriptor/proto_byte_utils.pb @@ -0,0 +1,13 @@ + +ú +test_proto.proto"Ý + MyMessage +id (Rid +name ( Rname +active (Ractive, +address ( 2.MyMessage.AddressRaddressf +Address +street ( Rstreet +city ( Rcity +state ( Rstate +zip_code ( RzipCodebproto3 \ No newline at end of file diff --git a/sdks/java/io/kafka/src/test/resources/proto_byte/proto_byte_utils.proto b/sdks/java/io/kafka/src/test/resources/proto_byte/proto_byte_utils.proto new file mode 100644 index 000000000000..aead141f4b9a --- /dev/null +++ b/sdks/java/io/kafka/src/test/resources/proto_byte/proto_byte_utils.proto @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +message MyMessage { + int32 id = 1; + string name = 2; + bool active = 3; + + // Nested field + message Address { + string street = 1; + string city = 2; + string state = 3; + string zip_code = 4; + } + + Address address = 4; +} diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml index 116d200a86a7..b19c1e5b063e 100644 --- a/sdks/python/apache_beam/yaml/standard_io.yaml +++ b/sdks/python/apache_beam/yaml/standard_io.yaml @@ -62,12 +62,16 @@ 'confluent_schema_registry_subject': 'confluentSchemaRegistrySubject' 'auto_offset_reset_config': 'autoOffsetResetConfig' 'error_handling': 'errorHandling' + 'file_descriptor_path': 'fileDescriptorPath' + 'message_name': 'messageName' 'WriteToKafka': 'format': 'format' 'topic': 'topic' 'bootstrap_servers': 'bootstrapServers' 'producer_config_updates': 'ProducerConfigUpdates' 'error_handling': 'errorHandling' + 'file_descriptor_path': 'fileDescriptorPath' + 'message_name': 'messageName' underlying_provider: type: beamJar transforms:
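The changes above wire the new `ProtoByteUtils` helpers into the Kafka schema transforms (`format: PROTO`) and expose `file_descriptor_path` and `message_name` through Beam YAML. The sketch below shows how the helpers introduced in this change compose on their own; it is illustrative only — the descriptor path and field values are hypothetical, and it assumes a descriptor set generated with the `protoc` commands documented in the READMEs above.

```java
import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.Row;

public class ProtoByteUtilsRoundTrip {
  public static void main(String[] args) {
    // Hypothetical location of a descriptor set produced by
    // `protoc --descriptor_set_out=...` for proto_byte_utils.proto.
    String descriptorPath = "/tmp/proto_byte_utils.pb";
    String messageName = "MyMessage";

    // Beam schema derived from the MyMessage proto definition.
    Schema schema = ProtoByteUtils.getBeamSchemaFromProto(descriptorPath, messageName);

    // Row -> proto bytes, as used by the Kafka write transform for format: PROTO.
    SerializableFunction<Row, byte[]> toBytes =
        ProtoByteUtils.getRowToProtoBytes(descriptorPath, messageName);

    // Proto bytes -> Row, as used by the Kafka read transform for format: PROTO.
    SerializableFunction<byte[], Row> toRow =
        ProtoByteUtils.getProtoBytesToRowFunction(descriptorPath, messageName);

    // Example values only; field names follow the MyMessage definition above.
    Row row =
        Row.withSchema(schema)
            .withFieldValue("id", 42)
            .withFieldValue("name", "Doe")
            .withFieldValue("active", true)
            .withFieldValue("address.street", "fake street")
            .withFieldValue("address.city", "seattle")
            .withFieldValue("address.state", "wa")
            .withFieldValue("address.zip_code", "TO-1234")
            .build();

    byte[] encoded = toBytes.apply(row);
    Row decoded = toRow.apply(encoded);
    System.out.println(decoded);
  }
}
```

Both mappers are `SerializableFunction`s, so the Kafka read and write transforms can close over them in their `DoFn`s and serialize them to workers instead of re-reading the descriptor set per element.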