diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto index aa9e70c7a871..429371e11055 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto @@ -111,6 +111,19 @@ message BuilderMethod { bytes payload = 3; } +message Annotations { + enum Enum { + // The annotation key for the encoded configuration Row used to build a transform + CONFIG_ROW_KEY = 0 [(org.apache.beam.model.pipeline.v1.beam_constant) = "config_row"]; + // The annotation key for the configuration Schema used to decode the configuration Row + CONFIG_ROW_SCHEMA_KEY = 1 [(org.apache.beam.model.pipeline.v1.beam_constant) = "config_row_schema"]; + // If ths transform is a SchemaTransform, this is the annotation key for the SchemaTransform's URN + SCHEMATRANSFORM_URN_KEY = 2 [(org.apache.beam.model.pipeline.v1.beam_constant) = "schematransform_urn"]; + // If the transform is a ManagedSchemaTransform, this is the annotation key for the underlying SchemaTransform's URN + MANAGED_UNDERLYING_TRANSFORM_URN_KEY = 3 [(org.apache.beam.model.pipeline.v1.beam_constant) = "managed_underlying_transform_urn"]; + } +} + // Payload for a Schema-aware PTransform. // This is a transform that is aware of its input and output PCollection schemas // and is configured using Beam Schema-compatible parameters. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/BeamUrns.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/BeamUrns.java index 05bb2b0e0a00..f0493de3696a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/BeamUrns.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/BeamUrns.java @@ -26,4 +26,9 @@ public class BeamUrns { public static String getUrn(ProtocolMessageEnum value) { return value.getValueDescriptor().getOptions().getExtension(RunnerApi.beamUrn); } + + /** Returns the constant value of a given enum annotated with [(beam_constant)]. */ + public static String getConstant(ProtocolMessageEnum value) { + return value.getValueDescriptor().getOptions().getExtension(RunnerApi.beamConstant); + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PTransformTranslation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PTransformTranslation.java index 5dc84897d380..e2b6d95057fd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PTransformTranslation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PTransformTranslation.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.util.construction; +import static org.apache.beam.model.pipeline.v1.ExternalTransforms.Annotations; import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM; import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; @@ -43,6 +44,7 @@ import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.SchemaTranslation; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.display.DisplayData; @@ -94,16 +96,12 @@ public class PTransformTranslation { public static final String MAP_WINDOWS_TRANSFORM_URN = "beam:transform:map_windows:v1"; public static final String MERGE_WINDOWS_TRANSFORM_URN = "beam:transform:merge_windows:v1"; public static final String TO_STRING_TRANSFORM_URN = "beam:transform:to_string:v1"; + public static final String MANAGED_TRANSFORM_URN = "beam:transform:managed:v1"; // Required runner implemented transforms. These transforms should never specify an environment. public static final ImmutableSet RUNNER_IMPLEMENTED_TRANSFORMS = ImmutableSet.of(GROUP_BY_KEY_TRANSFORM_URN, IMPULSE_TRANSFORM_URN); - public static final String CONFIG_ROW_KEY = "config_row"; - - public static final String CONFIG_ROW_SCHEMA_KEY = "config_row_schema"; - public static final String SCHEMATRANSFORM_URN_KEY = "schematransform_urn"; - // DeprecatedPrimitives /** * @deprecated SDKs should move away from creating `Read` transforms and migrate to using Impulse @@ -522,11 +520,28 @@ public RunnerApi.PTransform translate( } if (spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))) { + ExternalTransforms.SchemaTransformPayload payload = + ExternalTransforms.SchemaTransformPayload.parseFrom(spec.getPayload()); + String identifier = payload.getIdentifier(); transformBuilder.putAnnotations( - SCHEMATRANSFORM_URN_KEY, - ByteString.copyFromUtf8( - ExternalTransforms.SchemaTransformPayload.parseFrom(spec.getPayload()) - .getIdentifier())); + BeamUrns.getConstant(Annotations.Enum.SCHEMATRANSFORM_URN_KEY), + ByteString.copyFromUtf8(identifier)); + if (identifier.equals(MANAGED_TRANSFORM_URN)) { + Schema configSchema = + SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + Row configRow = + RowCoder.of(configSchema).decode(payload.getConfigurationRow().newInput()); + String underlyingIdentifier = configRow.getString("transform_identifier"); + if (underlyingIdentifier == null) { + throw new IllegalStateException( + String.format( + "Encountered a Managed Transform that has an empty \"transform_identifier\": \n%s", + configRow)); + } + transformBuilder.putAnnotations( + BeamUrns.getConstant(Annotations.Enum.MANAGED_UNDERLYING_TRANSFORM_URN_KEY), + ByteString.copyFromUtf8(underlyingIdentifier)); + } } } @@ -546,12 +561,12 @@ public RunnerApi.PTransform translate( } if (configRow != null) { transformBuilder.putAnnotations( - CONFIG_ROW_KEY, + BeamUrns.getConstant(Annotations.Enum.CONFIG_ROW_KEY), ByteString.copyFrom( CoderUtils.encodeToByteArray(RowCoder.of(configRow.getSchema()), configRow))); transformBuilder.putAnnotations( - CONFIG_ROW_SCHEMA_KEY, + BeamUrns.getConstant(Annotations.Enum.CONFIG_ROW_SCHEMA_KEY), ByteString.copyFrom( SchemaTranslation.schemaToProto(configRow.getSchema(), true).toByteArray())); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/TransformUpgrader.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/TransformUpgrader.java index deaa77d9b1be..941a5daf689b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/TransformUpgrader.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/TransformUpgrader.java @@ -185,9 +185,11 @@ RunnerApi.Pipeline updateTransformViaTransformService( throw new IllegalArgumentException("Could not find a transform with the ID " + transformId); } ByteString configRowBytes = - transformToUpgrade.getAnnotationsOrThrow(PTransformTranslation.CONFIG_ROW_KEY); + transformToUpgrade.getAnnotationsOrThrow( + BeamUrns.getConstant(ExternalTransforms.Annotations.Enum.CONFIG_ROW_KEY)); ByteString configRowSchemaBytes = - transformToUpgrade.getAnnotationsOrThrow(PTransformTranslation.CONFIG_ROW_SCHEMA_KEY); + transformToUpgrade.getAnnotationsOrThrow( + BeamUrns.getConstant(ExternalTransforms.Annotations.Enum.CONFIG_ROW_SCHEMA_KEY)); SchemaApi.Schema configRowSchemaProto = SchemaApi.Schema.parseFrom(configRowSchemaBytes.toByteArray()); diff --git a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java index 7a418976079f..0b0ad532dbd4 100644 --- a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java +++ b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java @@ -17,11 +17,16 @@ */ package org.apache.beam.sdk.managed; +import static org.apache.beam.model.pipeline.v1.ExternalTransforms.Annotations.Enum.CONFIG_ROW_KEY; +import static org.apache.beam.model.pipeline.v1.ExternalTransforms.Annotations.Enum.CONFIG_ROW_SCHEMA_KEY; +import static org.apache.beam.model.pipeline.v1.ExternalTransforms.Annotations.Enum.MANAGED_UNDERLYING_TRANSFORM_URN_KEY; +import static org.apache.beam.model.pipeline.v1.ExternalTransforms.Annotations.Enum.SCHEMATRANSFORM_URN_KEY; import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM; import static org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload; import static org.apache.beam.sdk.managed.ManagedSchemaTransformProvider.ManagedConfig; import static org.apache.beam.sdk.managed.ManagedSchemaTransformProvider.ManagedSchemaTransform; import static org.apache.beam.sdk.managed.ManagedSchemaTransformTranslation.ManagedSchemaTransformTranslator; +import static org.apache.beam.sdk.util.construction.PTransformTranslation.MANAGED_TRANSFORM_URN; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; @@ -41,11 +46,13 @@ import org.apache.beam.sdk.schemas.SchemaTranslation; import org.apache.beam.sdk.schemas.utils.YamlUtils; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.construction.BeamUrns; import org.apache.beam.sdk.util.construction.PipelineTranslation; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.junit.Test; @@ -154,9 +161,38 @@ public void testProtoTranslation() throws Exception { }) .collect(Collectors.toList()); assertEquals(1, managedTransformProto.size()); - RunnerApi.FunctionSpec spec = managedTransformProto.get(0).getSpec(); + RunnerApi.PTransform convertedTransform = managedTransformProto.get(0); - // Check that the proto contains correct values + // Check the transform proto contains the correct annotations. + // These annotations can be accessed and used by the runner to make decisions + Row managedConfigRow = + Row.withSchema(PROVIDER.configurationSchema()) + .withFieldValue("transform_identifier", TestSchemaTransformProvider.IDENTIFIER) + .withFieldValue("config", yamlStringConfig) + .build(); + Map expectedAnnotations = + ImmutableMap.builder() + .put( + BeamUrns.getConstant(SCHEMATRANSFORM_URN_KEY), + ByteString.copyFromUtf8(MANAGED_TRANSFORM_URN)) + .put( + BeamUrns.getConstant(MANAGED_UNDERLYING_TRANSFORM_URN_KEY), + ByteString.copyFromUtf8(TestSchemaTransformProvider.IDENTIFIER)) + .put( + BeamUrns.getConstant(CONFIG_ROW_KEY), + ByteString.copyFrom( + CoderUtils.encodeToByteArray( + RowCoder.of(PROVIDER.configurationSchema()), managedConfigRow))) + .put( + BeamUrns.getConstant(CONFIG_ROW_SCHEMA_KEY), + ByteString.copyFrom( + SchemaTranslation.schemaToProto(PROVIDER.configurationSchema(), true) + .toByteArray())) + .build(); + assertEquals(expectedAnnotations, convertedTransform.getAnnotationsMap()); + + // Check that the spec proto contains correct values + RunnerApi.FunctionSpec spec = convertedTransform.getSpec(); SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); assertEquals(PROVIDER.identifier(), payload.getIdentifier()); Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());