From 3da18c8821b7d722ee8fbb8712a7322fb4df9a83 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 28 Sep 2023 12:19:36 -0700 Subject: [PATCH 1/3] Ensure configuration schema used to decode configuration. Just because schemas are compatible doesn't mean that they're identical (e.g. up to field ordering). This is important as schema generation in Java is not always determanistic, so even if the payload is encoded with the exact schema that was previously provided it might not agree when the request comes in. --- .../service/ExpansionServiceSchemaTransformProvider.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java index 8d74f2f6117a..ead1fa67dc98 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java @@ -129,8 +129,7 @@ public PTransform getTransform(FunctionSpec spec) { Row configRow; try { configRow = - RowCoder.of(provider.configurationSchema()) - .decode(payload.getConfigurationRow().newInput()); + RowCoder.of(configSchemaFromRequest).decode(payload.getConfigurationRow().newInput()); } catch (IOException e) { throw new RuntimeException("Error decoding payload", e); } From 79cadac9f337aa7cd309ca1386b28b9ddab4c7c5 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 28 Sep 2023 13:43:17 -0700 Subject: [PATCH 2/3] Add some tests. --- ...ionServiceSchemaTransformProviderTest.java | 112 +++++++++++------- 1 file changed, 72 insertions(+), 40 deletions(-) diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java index 141d2b48b105..27f8083fcd32 100644 --- a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java +++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java @@ -19,9 +19,9 @@ import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import com.google.auto.service.AutoService; -import java.io.IOException; import java.util.ArrayList; import java.util.List; import org.apache.beam.model.expansion.v1.ExpansionApi; @@ -32,6 +32,7 @@ import org.apache.beam.runners.core.construction.ParDoTranslation; import org.apache.beam.runners.core.construction.PipelineTranslation; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.schemas.JavaFieldSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; @@ -48,10 +49,11 @@ import org.apache.beam.sdk.transforms.InferableFunction; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.InvalidProtocolBufferException; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; @@ -74,6 +76,13 @@ public class ExpansionServiceSchemaTransformProviderTest { Field.of("str1", FieldType.STRING), Field.of("str2", FieldType.STRING)); + private static final Schema TEST_SCHEMATRANSFORM_EQUIVALENT_CONFIG_SCHEMA = + Schema.of( + Field.of("str2", FieldType.STRING), + Field.of("str1", FieldType.STRING), + Field.of("int2", FieldType.INT32), + Field.of("int1", FieldType.INT32)); + private ExpansionService expansionService = new ExpansionService(); @DefaultSchema(JavaFieldSchema.class) @@ -344,31 +353,13 @@ public void testSchemaTransformExpansion() { .withFieldValue("str2", "bbb") .build(); - ByteStringOutputStream outputStream = new ByteStringOutputStream(); - try { - SchemaCoder.of(configRow.getSchema()).encode(configRow, outputStream); - } catch (IOException e) { - throw new RuntimeException(e); - } - - ExternalTransforms.SchemaTransformPayload payload = - ExternalTransforms.SchemaTransformPayload.newBuilder() - .setIdentifier("dummy_id") - .setConfigurationRow(outputStream.toByteString()) - .setConfigurationSchema( - SchemaTranslation.schemaToProto(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA, true)) - .build(); - ExpansionApi.ExpansionRequest request = ExpansionApi.ExpansionRequest.newBuilder() .setComponents(pipelineProto.getComponents()) .setTransform( RunnerApi.PTransform.newBuilder() .setUniqueName(TEST_NAME) - .setSpec( - RunnerApi.FunctionSpec.newBuilder() - .setUrn(getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM)) - .setPayload(payload.toByteString())) + .setSpec(createSpec("dummy_id", configRow)) .putInputs("input1", inputPcollId)) .setNamespace(TEST_NAMESPACE) .build(); @@ -403,35 +394,18 @@ public void testSchemaTransformExpansionMultiInputMultiOutput() { .withFieldValue("str2", "bbb") .build(); - ByteStringOutputStream outputStream = new ByteStringOutputStream(); - try { - SchemaCoder.of(configRow.getSchema()).encode(configRow, outputStream); - } catch (IOException e) { - throw new RuntimeException(e); - } - - ExternalTransforms.SchemaTransformPayload payload = - ExternalTransforms.SchemaTransformPayload.newBuilder() - .setIdentifier("dummy_id_multi_input_multi_output") - .setConfigurationRow(outputStream.toByteString()) - .setConfigurationSchema( - SchemaTranslation.schemaToProto(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA, true)) - .build(); - ExpansionApi.ExpansionRequest request = ExpansionApi.ExpansionRequest.newBuilder() .setComponents(pipelineProto.getComponents()) .setTransform( RunnerApi.PTransform.newBuilder() .setUniqueName(TEST_NAME) - .setSpec( - RunnerApi.FunctionSpec.newBuilder() - .setUrn(getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM)) - .setPayload(payload.toByteString())) + .setSpec(createSpec("dummy_id_multi_input_multi_output", configRow)) .putInputs("input1", inputPcollIds.get(0)) .putInputs("input2", inputPcollIds.get(1))) .setNamespace(TEST_NAMESPACE) .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); RunnerApi.PTransform expandedTransform = response.getTransform(); @@ -440,4 +414,62 @@ public void testSchemaTransformExpansionMultiInputMultiOutput() { assertEquals(2, expandedTransform.getOutputsCount()); verifyLeafTransforms(response, 2); } + + @Test + public void testSchematransformEquivalentConfigSchema() throws CoderException { + Row configRow = + Row.withSchema(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA) + .withFieldValue("int1", 111) + .withFieldValue("int2", 222) + .withFieldValue("str1", "aaa") + .withFieldValue("str2", "bbb") + .build(); + + RunnerApi.FunctionSpec spec = createSpec("dummy_id", configRow); + + Row equivalentConfigRow = + Row.withSchema(TEST_SCHEMATRANSFORM_EQUIVALENT_CONFIG_SCHEMA) + .withFieldValue("int1", 111) + .withFieldValue("int2", 222) + .withFieldValue("str1", "aaa") + .withFieldValue("str2", "bbb") + .build(); + + RunnerApi.FunctionSpec equivalentSpec = createSpec("dummy_id", equivalentConfigRow); + + assertNotEquals(spec.getPayload(), equivalentSpec.getPayload()); + + TestSchemaTransform transform = + (TestSchemaTransform) ExpansionServiceSchemaTransformProvider.of().getTransform(spec); + TestSchemaTransform equivalentTransform = + (TestSchemaTransform) + ExpansionServiceSchemaTransformProvider.of().getTransform(equivalentSpec); + + assertEquals(transform.int1, equivalentTransform.int1); + assertEquals(transform.int2, equivalentTransform.int2); + assertEquals(transform.str1, equivalentTransform.str1); + assertEquals(transform.str2, equivalentTransform.str2); + } + + private RunnerApi.FunctionSpec createSpec(String identifier, Row configRow) { + byte[] encodedRow; + try { + encodedRow = CoderUtils.encodeToByteArray(SchemaCoder.of(configRow.getSchema()), configRow); + } catch (CoderException e) { + throw new RuntimeException(e); + } + ; + + ExternalTransforms.SchemaTransformPayload payload = + ExternalTransforms.SchemaTransformPayload.newBuilder() + .setIdentifier(identifier) + .setConfigurationRow(ByteString.copyFrom(encodedRow)) + .setConfigurationSchema(SchemaTranslation.schemaToProto(configRow.getSchema(), true)) + .build(); + + return RunnerApi.FunctionSpec.newBuilder() + .setUrn(getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM)) + .setPayload(payload.toByteString()) + .build(); + } } From 8a77e46ac142e401b395c56b29185fd3d1b594fb Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 28 Sep 2023 14:25:29 -0700 Subject: [PATCH 3/3] stray line --- .../service/ExpansionServiceSchemaTransformProviderTest.java | 1 - 1 file changed, 1 deletion(-) diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java index 27f8083fcd32..d7a665eabe0f 100644 --- a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java +++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java @@ -458,7 +458,6 @@ private RunnerApi.FunctionSpec createSpec(String identifier, Row configRow) { } catch (CoderException e) { throw new RuntimeException(e); } - ; ExternalTransforms.SchemaTransformPayload payload = ExternalTransforms.SchemaTransformPayload.newBuilder()