Skip to content

Commit

Permalink
Updates Expansion Service Container to support upgrading using the sc…
Browse files Browse the repository at this point in the history
…hema-transform ID (#31451)

* Updates Expansion Service to support upgrading using the schema-transform ID

* Fixes a test failure

* Triggering Iceberg test suite

* Adding a unit test
  • Loading branch information
chamikaramj authored May 31, 2024
1 parent 19d57d0 commit 7b6f941
Show file tree
Hide file tree
Showing 11 changed files with 270 additions and 78 deletions.
3 changes: 3 additions & 0 deletions .github/trigger_files/IO_Iceberg_Integration_Tests.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run"
}
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,9 @@ class BeamModulePlugin implements Plugin<Project> {
aws_java_sdk2_profiles : "software.amazon.awssdk:profiles:$aws_java_sdk2_version",
azure_sdk_bom : "com.azure:azure-sdk-bom:1.2.14",
bigdataoss_gcsio : "com.google.cloud.bigdataoss:gcsio:$google_cloud_bigdataoss_version",
bigdataoss_gcs_connector : "com.google.cloud.bigdataoss:gcs-connector:hadoop2-$google_cloud_bigdataoss_version",
bigdataoss_util : "com.google.cloud.bigdataoss:util:$google_cloud_bigdataoss_version",
bigdataoss_util_hadoop : "com.google.cloud.bigdataoss:util-hadoop:hadoop2-$google_cloud_bigdataoss_version",
byte_buddy : "net.bytebuddy:byte-buddy:1.14.12",
cassandra_driver_core : "com.datastax.cassandra:cassandra-driver-core:$cassandra_driver_version",
cassandra_driver_mapping : "com.datastax.cassandra:cassandra-driver-mapping:$cassandra_driver_version",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@

public interface ExternalTranslationOptions extends PipelineOptions {

@Description("Set of URNs of transforms to be overriden using the transform service.")
@Description(
"Set of URNs of transforms to be overriden using the transform service. The provided strings "
+ "can be transform URNs of schema-transform IDs")
@Default.InstanceFactory(EmptyListDefault.class)
List<String> getTransformsToOverride();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ public RunnerApi.PTransform translate(
if (underlyingIdentifier == null) {
throw new IllegalStateException(
String.format(
"Encountered a Managed Transform that has an empty \"transform_identifier\": \n%s",
"Encountered a Managed Transform that has an empty \"transform_identifier\": %n%s",
configRow));
}
transformBuilder.putAnnotations(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.util.construction;

import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM;

import com.fasterxml.jackson.core.Version;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand Down Expand Up @@ -51,6 +53,7 @@
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannelBuilder;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Splitter;
Expand Down Expand Up @@ -113,6 +116,22 @@ public RunnerApi.Pipeline upgradeTransformsViaTransformService(
if (urn != null && urnsToOverride.contains(urn)) {
return true;
}

// Also check if the URN is a schema-transform ID.
if (urn.equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))) {
try {
ExternalTransforms.SchemaTransformPayload schemaTransformPayload =
ExternalTransforms.SchemaTransformPayload.parseFrom(
entry.getValue().getSpec().getPayload());
String schemaTransformId = schemaTransformPayload.getIdentifier();
if (urnsToOverride.contains(schemaTransformId)) {
return true;
}
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
}

return false;
})
.map(
Expand Down Expand Up @@ -184,28 +203,35 @@ RunnerApi.Pipeline updateTransformViaTransformService(
if (transformToUpgrade == null) {
throw new IllegalArgumentException("Could not find a transform with the ID " + transformId);
}
ByteString configRowBytes =
transformToUpgrade.getAnnotationsOrThrow(
BeamUrns.getConstant(ExternalTransforms.Annotations.Enum.CONFIG_ROW_KEY));
ByteString configRowSchemaBytes =
transformToUpgrade.getAnnotationsOrThrow(
BeamUrns.getConstant(ExternalTransforms.Annotations.Enum.CONFIG_ROW_SCHEMA_KEY));
SchemaApi.Schema configRowSchemaProto =
SchemaApi.Schema.parseFrom(configRowSchemaBytes.toByteArray());

ExternalTransforms.ExternalConfigurationPayload payload =
ExternalTransforms.ExternalConfigurationPayload.newBuilder()
.setSchema(configRowSchemaProto)
.setPayload(configRowBytes)
.build();

byte[] payloadBytes = null;

if (!transformToUpgrade.getSpec().getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))) {
ByteString configRowBytes =
transformToUpgrade.getAnnotationsOrThrow(
BeamUrns.getConstant(ExternalTransforms.Annotations.Enum.CONFIG_ROW_KEY));
ByteString configRowSchemaBytes =
transformToUpgrade.getAnnotationsOrThrow(
BeamUrns.getConstant(ExternalTransforms.Annotations.Enum.CONFIG_ROW_SCHEMA_KEY));
SchemaApi.Schema configRowSchemaProto =
SchemaApi.Schema.parseFrom(configRowSchemaBytes.toByteArray());
payloadBytes =
ExternalTransforms.ExternalConfigurationPayload.newBuilder()
.setSchema(configRowSchemaProto)
.setPayload(configRowBytes)
.build()
.toByteArray();
} else {
payloadBytes = transformToUpgrade.getSpec().getPayload().toByteArray();
}

RunnerApi.PTransform.Builder ptransformBuilder =
RunnerApi.PTransform.newBuilder()
.setUniqueName(transformToUpgrade.getUniqueName() + "_external")
.setSpec(
RunnerApi.FunctionSpec.newBuilder()
.setUrn(transformToUpgrade.getSpec().getUrn())
.setPayload(ByteString.copyFrom(payload.toByteArray()))
.setPayload(ByteString.copyFrom(payloadBytes))
.build());

for (Map.Entry<String, String> entry : transformToUpgrade.getInputsMap().entrySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,22 @@
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation.SchemaTransformPayloadTranslator;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.ToString;
import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.apache.beam.sdk.util.construction.PTransformTranslation.TransformPayloadTranslator;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -158,6 +164,60 @@ public static class Registrar2 implements TransformPayloadTranslatorRegistrar {
}
}

/**
 * A no-op {@link SchemaTransformProvider} used to exercise upgrading transforms by their
 * schema-transform ID (rather than by transform URN).
 */
public static class TestSchemaTransformProvider implements SchemaTransformProvider {

  // The schema-transform ID that tests pass in the set of URNs/IDs to override.
  private static final String IDENTIFIER = "dummy_schema_transform";

  @Override
  public String identifier() {
    return IDENTIFIER;
  }

  @Override
  public Schema configurationSchema() {
    // The dummy transform takes no configuration, hence an empty schema.
    return Schema.builder().build();
  }

  @Override
  public SchemaTransform from(Row configuration) {
    // Configuration is ignored; the produced transform is stateless.
    return new TestSchemaTransform();
  }
}

/** A pass-through {@link SchemaTransform} that returns its input tuple unchanged. */
public static class TestSchemaTransform extends SchemaTransform {

  @Override
  public PCollectionRowTuple expand(PCollectionRowTuple input) {
    // Identity expansion: no transformation is applied to the input rows.
    return input;
  }
}

/**
 * Payload translator mapping {@link TestSchemaTransform} to an empty configuration row, so the
 * upgrade machinery can serialize it as a schema-transform.
 */
static class TestSchemaTransformTranslator
    extends SchemaTransformPayloadTranslator<TestSchemaTransform> {

  @Override
  public SchemaTransformProvider provider() {
    return new TestSchemaTransformProvider();
  }

  @Override
  public Row toConfigRow(TestSchemaTransform transform) {
    // The dummy transform carries no configuration, so the config row has an empty schema.
    Schema emptySchema = Schema.builder().build();
    return Row.withSchema(emptySchema).build();
  }
}

/**
 * Registers {@link TestSchemaTransformTranslator} via {@code AutoService} so the payload-translator
 * lookup used during transform upgrades can discover it.
 */
@AutoService(TransformPayloadTranslatorRegistrar.class)
public static class TestSchemaTransformPayloadTranslatorRegistrar
    implements TransformPayloadTranslatorRegistrar {
  @Override
  @SuppressWarnings({
    "rawtypes",
  })
  public Map<? extends Class<? extends PTransform>, ? extends TransformPayloadTranslator>
      getTransformPayloadTranslators() {
    // Single-entry map: the dummy schema-transform class to its translator.
    return ImmutableMap.<Class<? extends PTransform>, TransformPayloadTranslator>builder()
        .put(TestSchemaTransform.class, new TestSchemaTransformTranslator())
        .build();
  }
}

static class TestExpansionServiceClientFactory implements ExpansionServiceClientFactory {
ExpansionApi.ExpansionResponse response;

Expand All @@ -183,34 +243,49 @@ public ExpansionApi.ExpansionResponse expand(ExpansionApi.ExpansionRequest reque
.getTransformsMap()
.get("TransformUpgraderTest-TestTransform2");
}

boolean schemaTransformTest = false;
if (transformToUpgrade == null) {
// This is running a schema-transform test.
transformToUpgrade =
request
.getComponents()
.getTransformsMap()
.get("TransformUpgraderTest-TestSchemaTransform");
schemaTransformTest = true;
}

if (!transformToUpgrade
.getSpec()
.getUrn()
.equals(request.getTransform().getSpec().getUrn())) {
throw new RuntimeException("Could not find a valid transform to upgrade");
}

Integer oldParam;
try {
ByteArrayInputStream byteArrayInputStream =
new ByteArrayInputStream(transformToUpgrade.getSpec().getPayload().toByteArray());
ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream);
oldParam = (Integer) objectInputStream.readObject();
} catch (Exception e) {
throw new RuntimeException(e);
}

RunnerApi.PTransform.Builder upgradedTransform = transformToUpgrade.toBuilder();
FunctionSpec.Builder specBuilder = upgradedTransform.getSpecBuilder();

ByteStringOutputStream byteStringOutputStream = new ByteStringOutputStream();
try {
ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteStringOutputStream);
objectOutputStream.writeObject(oldParam * 2);
objectOutputStream.flush();
specBuilder.setPayload(byteStringOutputStream.toByteString());
} catch (IOException e) {
throw new RuntimeException(e);
if (!schemaTransformTest) {
Integer oldParam;
try {
ByteArrayInputStream byteArrayInputStream =
new ByteArrayInputStream(transformToUpgrade.getSpec().getPayload().toByteArray());
ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream);
oldParam = (Integer) objectInputStream.readObject();
} catch (Exception e) {
throw new RuntimeException(e);
}

ByteStringOutputStream byteStringOutputStream = new ByteStringOutputStream();
try {
ObjectOutputStream objectOutputStream =
new ObjectOutputStream(byteStringOutputStream);
objectOutputStream.writeObject(oldParam * 2);
objectOutputStream.flush();
specBuilder.setPayload(byteStringOutputStream.toByteString());
} catch (IOException e) {
throw new RuntimeException(e);
}
}

upgradedTransform.setSpec(specBuilder.build());
Expand Down Expand Up @@ -291,6 +366,34 @@ public void testTransformUpgrade() throws Exception {
assertTrue(upgradedTransform.getAnnotationsMap().containsKey(TransformUpgrader.UPGRADE_KEY));
}

@Test
public void testTransformUpgradeSchemaTransform() throws Exception {
  // Construct a trivial pipeline containing only the dummy schema-transform.
  Pipeline p = Pipeline.create();
  PCollectionRowTuple.empty(p).apply(new TestSchemaTransform());
  RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p, false);

  // Request the override by schema-transform ID (not by a transform URN).
  ExternalTranslationOptions options =
      PipelineOptionsFactory.create().as(ExternalTranslationOptions.class);
  List<String> overrides = ImmutableList.of("dummy_schema_transform");
  options.setTransformsToOverride(overrides);
  options.setTransformServiceAddress("dummyaddress");

  RunnerApi.Pipeline upgradedProto =
      TransformUpgrader.of(new TestExpansionServiceClientFactory())
          .upgradeTransformsViaTransformService(pipelineProto, overrides, options);

  RunnerApi.PTransform result =
      upgradedProto
          .getComponents()
          .getTransformsMap()
          .get("TransformUpgraderTest-TestSchemaTransform");

  // The upgraded transform must carry the upgrade annotation.
  assertTrue(result.getAnnotationsMap().containsKey(TransformUpgrader.UPGRADE_KEY));
}

@Test
public void testTransformUpgradeMultipleOccurrences() throws Exception {
Pipeline pipeline = Pipeline.create();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ allowlist:
- "beam:transform:org.apache.beam:schemaio_jdbc_read:v1"
- "beam:transform:org.apache.beam:schemaio_jdbc_write:v1"
- "beam:schematransform:org.apache.beam:bigquery_storage_write:v1"
# By default, the Expansion Service container will include all dependencies in
# the classpath. The following config can be used to override this behavior per
# transform URN or schema-transform ID.
dependencies:
"beam:transform:org.apache.beam:kafka_read_with_metadata:v1":
- path: "jars/beam-sdks-java-io-expansion-service.jar"
Expand Down
Loading

0 comments on commit 7b6f941

Please sign in to comment.