Skip to content

Commit

Permalink
[YAML] - Kafka Proto String schema (#29835)
Browse files Browse the repository at this point in the history
* [YAML] - Kafka Proto String schema
  • Loading branch information
ffernandez92 authored Jan 9, 2024
1 parent 8aa16df commit 6066af3
Show file tree
Hide file tree
Showing 10 changed files with 373 additions and 125 deletions.
2 changes: 2 additions & 0 deletions sdks/java/extensions/protobuf/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ dependencies {
implementation library.java.slf4j_api
implementation project(path: ":sdks:java:core", configuration: "shadow")
implementation library.java.protobuf_java
implementation("com.squareup.wire:wire-schema-jvm:4.9.3")
implementation("io.apicurio:apicurio-registry-protobuf-schema-utilities:3.0.0.M2")
testImplementation project(path: ":sdks:java:core", configuration: "shadowTest")
testImplementation library.java.junit
testRuntimeOnly library.java.slf4j_jdk14
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.InvalidProtocolBufferException;
import com.squareup.wire.schema.Location;
import com.squareup.wire.schema.internal.parser.ProtoFileElement;
import com.squareup.wire.schema.internal.parser.ProtoParser;
import io.apicurio.registry.utils.protobuf.schema.FileDescriptorUtils;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
Expand Down Expand Up @@ -55,6 +59,8 @@ public class ProtoByteUtils {

  private static final Logger LOG = LoggerFactory.getLogger(ProtoByteUtils.class);

  // Synthetic source location used when parsing schema strings that have no
  // backing .proto file on disk; the empty path is intentional.
  private static final Location LOCATION = Location.get("");

/**
* Retrieves a Beam Schema from a Protocol Buffer message.
*
Expand All @@ -68,6 +74,68 @@ public static Schema getBeamSchemaFromProto(String fileDescriptorPath, String me
return ProtoDynamicMessageSchema.forDescriptor(protoDomain, messageName).getSchema();
}

/**
* Parses the given Protocol Buffers schema string, retrieves the Descriptor for the specified
* message name, and constructs a Beam Schema from it.
*
* @param schemaString The Protocol Buffers schema string.
* @param messageName The name of the message type for which the Beam Schema is desired.
* @return The Beam Schema constructed from the specified Protocol Buffers schema.
* @throws RuntimeException If there is an error during parsing, descriptor retrieval, or schema
* construction.
*/
public static Schema getBeamSchemaFromProtoSchema(String schemaString, String messageName) {
Descriptors.Descriptor descriptor = getDescriptorFromProtoSchema(schemaString, messageName);
return ProtoDynamicMessageSchema.forDescriptor(ProtoDomain.buildFrom(descriptor), descriptor)
.getSchema();
}

/**
* Parses the given Protocol Buffers schema string, retrieves the FileDescriptor, and returns the
* Descriptor for the specified message name.
*
* @param schemaString The Protocol Buffers schema string.
* @param messageName The name of the message type for which the descriptor is desired.
* @return The Descriptor for the specified message name.
* @throws RuntimeException If there is an error during parsing or descriptor validation.
*/
private static Descriptors.Descriptor getDescriptorFromProtoSchema(
final String schemaString, final String messageName) {
ProtoFileElement result = ProtoParser.Companion.parse(LOCATION, schemaString);
try {
Descriptors.FileDescriptor fileDescriptor =
FileDescriptorUtils.protoFileToFileDescriptor(result);
return fileDescriptor.findMessageTypeByName(messageName);
} catch (Descriptors.DescriptorValidationException e) {
throw new RuntimeException(e);
}
}

public static SerializableFunction<byte[], Row> getProtoBytesToRowFromSchemaFunction(
String schemaString, String messageName) {

Descriptors.Descriptor descriptor = getDescriptorFromProtoSchema(schemaString, messageName);

ProtoDynamicMessageSchema<DynamicMessage> protoDynamicMessageSchema =
ProtoDynamicMessageSchema.forDescriptor(ProtoDomain.buildFrom(descriptor), descriptor);
return new SimpleFunction<byte[], Row>() {
@Override
public Row apply(byte[] input) {
try {
Descriptors.Descriptor descriptorFunction =
getDescriptorFromProtoSchema(schemaString, messageName);
DynamicMessage dynamicMessage = DynamicMessage.parseFrom(descriptorFunction, input);
SerializableFunction<DynamicMessage, Row> res =
protoDynamicMessageSchema.getToRowFunction();
return res.apply(dynamicMessage);
} catch (InvalidProtocolBufferException e) {
LOG.error("Error parsing to DynamicMessage", e);
throw new RuntimeException(e);
}
}
};
}

public static SerializableFunction<byte[], Row> getProtoBytesToRowFunction(
String fileDescriptorPath, String messageName) {

Expand Down Expand Up @@ -96,6 +164,23 @@ public Row apply(byte[] input) {
};
}

public static SerializableFunction<Row, byte[]> getRowToProtoBytesFromSchema(
String schemaString, String messageName) {

Descriptors.Descriptor descriptor = getDescriptorFromProtoSchema(schemaString, messageName);

ProtoDynamicMessageSchema<DynamicMessage> protoDynamicMessageSchema =
ProtoDynamicMessageSchema.forDescriptor(ProtoDomain.buildFrom(descriptor), descriptor);
return new SimpleFunction<Row, byte[]>() {
@Override
public byte[] apply(Row input) {
SerializableFunction<Row, DynamicMessage> res =
protoDynamicMessageSchema.getFromRowFunction();
return res.apply(input).toByteArray();
}
};
}

public static SerializableFunction<Row, byte[]> getRowToProtoBytes(
String fileDescriptorPath, String messageName) {
ProtoSchemaInfo dynamicProtoDomain = getProtoDomain(fileDescriptorPath);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,25 @@
@RunWith(JUnit4.class)
public class ProtoByteUtilsTest {

  // Inline proto3 schema used by the schema-string (no descriptor file) test paths;
  // its field layout mirrors the descriptor-based SCHEMA fixture used elsewhere in
  // this test class.
  private static final String PROTO_STRING_SCHEMA =
      "syntax = \"proto3\";\n"
          + "\n"
          + "message MyMessage {\n"
          + "  int32 id = 1;\n"
          + "  string name = 2;\n"
          + "  bool active = 3;\n"
          + "\n"
          + "  // Nested field\n"
          + "  message Address {\n"
          + "    string street = 1;\n"
          + "    string city = 2;\n"
          + "    string state = 3;\n"
          + "    string zip_code = 4;\n"
          + "  }\n"
          + "\n"
          + "  Address address = 4;\n"
          + "}";

private static final String DESCRIPTOR_PATH =
Objects.requireNonNull(
ProtoByteUtilsTest.class.getResource(
Expand Down Expand Up @@ -59,13 +78,26 @@ public void testProtoSchemaToBeamSchema() {
Assert.assertEquals(schema.getFieldNames(), SCHEMA.getFieldNames());
}

@Test
public void testProtoSchemaStringToBeamSchema() {
Schema schema = ProtoByteUtils.getBeamSchemaFromProtoSchema(PROTO_STRING_SCHEMA, "MyMessage");
Assert.assertEquals(schema.getFieldNames(), SCHEMA.getFieldNames());
}

@Test
public void testProtoBytesToRowFunctionGenerateSerializableFunction() {
SerializableFunction<byte[], Row> protoBytesToRowFunction =
ProtoByteUtils.getProtoBytesToRowFunction(DESCRIPTOR_PATH, MESSAGE_NAME);
Assert.assertNotNull(protoBytesToRowFunction);
}

@Test
public void testProtoBytesToRowSchemaStringGenerateSerializableFunction() {
SerializableFunction<byte[], Row> protoBytesToRowFunction =
ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(PROTO_STRING_SCHEMA, "MyMessage");
Assert.assertNotNull(protoBytesToRowFunction);
}

@Test(expected = java.lang.RuntimeException.class)
public void testProtoBytesToRowFunctionReturnsRowFailure() {
// Create a proto bytes to row function
Expand Down Expand Up @@ -95,4 +127,21 @@ public void testRowToProtoFunction() {
Assert.assertNotNull(
ProtoByteUtils.getRowToProtoBytes(DESCRIPTOR_PATH, MESSAGE_NAME).apply(row));
}

@Test
public void testRowToProtoSchemaFunction() {
Row row =
Row.withSchema(SCHEMA)
.withFieldValue("id", 1234)
.withFieldValue("name", "Doe")
.withFieldValue("active", false)
.withFieldValue("address.city", "seattle")
.withFieldValue("address.street", "fake street")
.withFieldValue("address.zip_code", "TO-1234")
.withFieldValue("address.state", "wa")
.build();

Assert.assertNotNull(
ProtoByteUtils.getRowToProtoBytesFromSchema(PROTO_STRING_SCHEMA, "MyMessage").apply(row));
}
}
1 change: 1 addition & 0 deletions sdks/java/io/kafka/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ dependencies {
testImplementation project(path: ":sdks:java:core", configuration: "shadowTest")
testImplementation project(":sdks:java:io:synthetic")
testImplementation project(path: ":sdks:java:extensions:avro", configuration: "testRuntimeMigration")
testImplementation project(path: ":sdks:java:extensions:protobuf", configuration: "testRuntimeMigration")
testImplementation project(path: ":sdks:java:io:common", configuration: "testRuntimeMigration")
testImplementation project(path: ":sdks:java:testing:test-utils", configuration: "testRuntimeMigration")
// For testing Cross-language transforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,29 @@ public void validate() {
    final String dataFormat = this.getFormat();
    // NOTE(review): validation here relies on `assert`, which is a no-op unless the
    // JVM runs with -ea; consider Preconditions.checkArgument so misconfiguration
    // always fails loudly in production — confirm against project conventions.
    assert dataFormat == null || VALID_DATA_FORMATS.contains(dataFormat)
        : "Valid data formats are " + VALID_DATA_FORMATS;

    final String inputSchema = this.getSchema();
    final String messageName = this.getMessageName();
    final String fileDescriptorPath = this.getFileDescriptorPath();
    final String confluentSchemaRegUrl = this.getConfluentSchemaRegistryUrl();
    final String confluentSchemaRegSubject = this.getConfluentSchemaRegistrySubject();

    // When a Confluent Schema Registry URL is given, the subject is mandatory and
    // any inline schema/format checks below are skipped.
    if (confluentSchemaRegUrl != null) {
      // NOTE(review): "though" in this message looks like a typo for "through";
      // the string is runtime behavior, so it is left untouched here.
      assert confluentSchemaRegSubject != null
          : "To read from Kafka, a schema must be provided directly or though Confluent "
              + "Schema Registry. Make sure you are providing one of these parameters.";
    } else if (dataFormat != null && dataFormat.equals("RAW")) {
      // RAW bytes carry no structure, so a schema is contradictory.
      assert inputSchema == null : "To read from Kafka in RAW format, you can't provide a schema.";
    } else if (dataFormat != null && dataFormat.equals("JSON")) {
      assert inputSchema != null : "To read from Kafka in JSON format, you must provide a schema.";
    } else if (dataFormat != null && dataFormat.equals("PROTO")) {
      // PROTO needs a message name plus either a descriptor file or an inline schema.
      assert messageName != null
          : "To read from Kafka in PROTO format, messageName must be provided.";
      assert fileDescriptorPath != null || inputSchema != null
          : "To read from Kafka in PROTO format, fileDescriptorPath or schema must be provided.";
    } else {
      // Fallthrough covers a null dataFormat as well; AVRO is presumably the default —
      // TODO confirm against the transform's format handling.
      assert inputSchema != null : "To read from Kafka in AVRO format, you must provide a schema.";
    }
  }

/** Instantiates a {@link KafkaReadSchemaTransformConfiguration.Builder} instance. */
Expand Down
Loading

0 comments on commit 6066af3

Please sign in to comment.