Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default translation for SchemaTransforms #31558

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
* An abstraction representing schema capable and aware transforms. The interface is intended to be
Expand All @@ -33,5 +36,39 @@
* compatibility guarantees and it should not be implemented outside of the Beam repository.
*/
@Internal
public abstract class SchemaTransform
extends PTransform<PCollectionRowTuple, PCollectionRowTuple> {}
public abstract class SchemaTransform extends PTransform<PCollectionRowTuple, PCollectionRowTuple> {
private @Nullable Row configurationRow;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like putting these (optional) private variables as part of the base class violates separation of concerns.

private @Nullable String identifier;
private boolean registered = false;
Copy link
Contributor Author

@ahmedabu98 ahmedabu98 Jun 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use "registered" throughout, but I'm not sure if it's the best terminology here. Open to other suggestions!


/**
* Stores the transform's identifier and configuration {@link Row} used to build this instance.
* Doing so allows this transform to be translated from/to proto using {@link
* org.apache.beam.sdk.util.construction.PTransformTranslation.SchemaTransformTranslator}.
*/
public SchemaTransform register(Row configurationRow, String identifier) {
this.configurationRow = configurationRow;
this.identifier = identifier;
registered = true;

return this;
}

public Row getConfigurationRow() {
return Preconditions.checkNotNull(
configurationRow,
"Could not fetch Row configuration for %s. Please store it using .register().",
getClass());
}

public String getIdentifier() {
return Preconditions.checkNotNull(
identifier,
"Could not fetch identifier for %s. Please store it using .register().",
getClass());
}

public boolean isRegistered() {
return registered;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.ServiceLoader;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
import org.apache.beam.model.pipeline.v1.SchemaApi;
import org.apache.beam.sdk.coders.RowCoder;
Expand All @@ -36,13 +39,24 @@
import org.checkerframework.checker.nullness.qual.Nullable;

/**
* A {@link TransformPayloadTranslator} implementation that translates between a Java {@link
* SchemaTransform} and a protobuf payload for that transform.
* A default {@link TransformPayloadTranslator} implementation for {@link SchemaTransform}s.
*
* <p>Note: This is only eligible for registered SchemaTransform instances (using {@link
* SchemaTransform#register(Row, String)} or {@link TypedSchemaTransformProvider#register(Object,
* SchemaTransform)}).
*/
public class SchemaTransformTranslation {
public abstract static class SchemaTransformPayloadTranslator<T extends SchemaTransform>
implements TransformPayloadTranslator<T> {
public abstract SchemaTransformProvider provider();
public static class SchemaTransformPayloadTranslator
implements TransformPayloadTranslator<SchemaTransform> {
private final SchemaTransformProvider provider;

public String identifier() {
return provider.identifier();
}

public SchemaTransformPayloadTranslator(SchemaTransformProvider provider) {
this.provider = provider;
}

@Override
public String getUrn() {
Expand All @@ -52,18 +66,19 @@ public String getUrn() {
@Override
@SuppressWarnings("argument")
public @Nullable FunctionSpec translate(
AppliedPTransform<?, ?, T> application, SdkComponents components) throws IOException {
AppliedPTransform<?, ?, SchemaTransform> application, SdkComponents components)
throws IOException {
SchemaApi.Schema expansionSchema =
SchemaTranslation.schemaToProto(provider().configurationSchema(), true);
SchemaTranslation.schemaToProto(provider.configurationSchema(), true);
Row configRow = toConfigRow(application.getTransform());
ByteArrayOutputStream os = new ByteArrayOutputStream();
RowCoder.of(provider().configurationSchema()).encode(configRow, os);
RowCoder.of(provider.configurationSchema()).encode(configRow, os);

return FunctionSpec.newBuilder()
.setUrn(getUrn())
.setPayload(
ExternalTransforms.SchemaTransformPayload.newBuilder()
.setIdentifier(provider().identifier())
.setIdentifier(provider.identifier())
.setConfigurationSchema(expansionSchema)
.setConfigurationRow(ByteString.copyFrom(os.toByteArray()))
.build()
Expand All @@ -72,8 +87,21 @@ public String getUrn() {
}

@Override
public T fromConfigRow(Row configRow, PipelineOptions options) {
return (T) provider().from(configRow);
public Row toConfigRow(SchemaTransform transform) {
return transform.getConfigurationRow();
}

@Override
public SchemaTransform fromConfigRow(Row configRow, PipelineOptions options) {
return provider.from(configRow);
}
}

public static Map<String, SchemaTransformPayloadTranslator> getDefaultTranslators() {
Map<String, SchemaTransformPayloadTranslator> translators = new HashMap<>();
for (SchemaTransformProvider provider : ServiceLoader.load(SchemaTransformProvider.class)) {
translators.put(provider.identifier(), new SchemaTransformPayloadTranslator(provider));
}
return translators;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,28 @@ protected Class<ConfigT> configurationClass() {
return (Class<ConfigT>) parameterizedType.getActualTypeArguments()[0];
}

/** Like {@link SchemaTransform#register(Row, String)}, but with a configuration POJO. */
protected SchemaTransform register(ConfigT configuration, SchemaTransform transform) {
SchemaRegistry registry = SchemaRegistry.createDefault();
try {
// Get initial row with values
// then sort lexicographically and convert to snake_case
Row configRow =
registry
.getToRowFunction(configurationClass())
.apply(configuration)
.sorted()
.toSnakeCase();
return transform.register(configRow, identifier());
} catch (NoSuchSchemaException e) {
throw new RuntimeException(
String.format(
"Unable to find schema for this SchemaTransform's config type: %s",
configurationClass()),
e);
}
}

/**
* Produce a SchemaTransform from ConfigT. Can throw a {@link InvalidConfigurationException} or a
* {@link InvalidSchemaException}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.CoderUtils;
Expand All @@ -59,6 +61,7 @@
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSortedSet;
Expand Down Expand Up @@ -253,6 +256,7 @@ private static Collection<TransformTranslator<?>> loadKnownTranslators() {
.add(new KnownTransformPayloadTranslator())
.add(ParDoTranslator.create())
.add(ExternalTranslator.create())
.add(new SchemaTransformTranslator())
.build();
}

Expand Down Expand Up @@ -581,6 +585,86 @@ public RunnerApi.PTransform translate(
}
}

/**
* Translates {@link SchemaTransform}s by populating the {@link FunctionSpec} with a {@link
* ExternalTransforms.SchemaTransformPayload} containing the transform's configuration {@link
* Schema} and {@link Row}.
*
* <p>This can be used as a default translator for SchemaTransforms. If further customization is
* needed, you can develop a {@link TransformPayloadTranslator} implementation and include it in a
* {@link TransformPayloadTranslatorRegistrar}, which will be picked up by {@link
* KnownTransformPayloadTranslator}.
*
* <p>Note: This default translator is only eligible for registered SchemaTransform instances
* (using {@link SchemaTransform#register(Row, String)} or {@link
* org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider#register(Object,
* SchemaTransform)}).
*/
private static class SchemaTransformTranslator implements TransformTranslator<SchemaTransform> {
@Override
public @Nullable String getUrn(SchemaTransform transform) {
return transform.getIdentifier();
}

@Override
public boolean canTranslate(PTransform transform) {
// Can translate only if the SchemaTransform's configuration Row and identifier are accessible
if (transform instanceof SchemaTransform) {
return (((SchemaTransform) transform).isRegistered());
}
return false;
}

@Override
public RunnerApi.PTransform translate(
AppliedPTransform<?, ?, ?> appliedPTransform,
List<AppliedPTransform<?, ?, ?>> subtransforms,
SdkComponents components)
throws IOException {
RunnerApi.PTransform.Builder transformBuilder =
translateAppliedPTransform(appliedPTransform, subtransforms, components);

String identifier = ((SchemaTransform) appliedPTransform.getTransform()).getIdentifier();
TransformPayloadTranslator payloadTranslator =
SchemaTransformTranslation.getDefaultTranslators().get(identifier);

FunctionSpec spec = payloadTranslator.translate(appliedPTransform, components);
Row configRow = payloadTranslator.toConfigRow(appliedPTransform.getTransform());

if (spec != null) {
transformBuilder.setSpec(spec);
transformBuilder.putAnnotations(
BeamUrns.getConstant(Annotations.Enum.SCHEMATRANSFORM_URN_KEY),
ByteString.copyFromUtf8(identifier));
if (identifier.equals(MANAGED_TRANSFORM_URN)) {
String underlyingIdentifier =
Preconditions.checkNotNull(
configRow.getString("transform_identifier"),
"Encountered a Managed Transform that has an empty \"transform_identifier\": %n%s",
configRow);
transformBuilder.putAnnotations(
BeamUrns.getConstant(Annotations.Enum.MANAGED_UNDERLYING_TRANSFORM_URN_KEY),
ByteString.copyFromUtf8(underlyingIdentifier));
}
}
transformBuilder.putAnnotations(
BeamUrns.getConstant(Annotations.Enum.CONFIG_ROW_KEY),
ByteString.copyFrom(
CoderUtils.encodeToByteArray(RowCoder.of(configRow.getSchema()), configRow)));
transformBuilder.putAnnotations(
BeamUrns.getConstant(Annotations.Enum.CONFIG_ROW_SCHEMA_KEY),
ByteString.copyFrom(
SchemaTranslation.schemaToProto(configRow.getSchema(), true).toByteArray()));

for (Entry<String, byte[]> annotation :
appliedPTransform.getTransform().getAnnotations().entrySet()) {
transformBuilder.putAnnotations(
annotation.getKey(), ByteString.copyFrom(annotation.getValue()));
}

return transformBuilder.build();
}
}
/**
* Translates an {@link AppliedPTransform} by:
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.schemas.transforms;

import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM;
import static org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation.SchemaTransformPayloadTranslator;
import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.util.construction.BeamUrns;
import org.apache.beam.sdk.util.construction.PipelineTranslation;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException;
import org.junit.Test;

/** Base class for standard {@link SchemaTransform} translation tests. */
public abstract class SchemaTransformTranslationTest {
protected abstract SchemaTransformProvider provider();

protected abstract Row configurationRow();

/** Input used for this SchemaTransform. Used to build a pipeline to test proto translation. */
protected PCollectionRowTuple input(Pipeline p) {
return PCollectionRowTuple.empty(p);
};

@Test
public void testRecreateTransformFromRow() {
SchemaTransformProvider provider = provider();
SchemaTransformPayloadTranslator translator = new SchemaTransformPayloadTranslator(provider);
SchemaTransform originalTransform = provider.from(configurationRow());

Row translatedConfigRow = translator.toConfigRow(originalTransform);
SchemaTransform translatedTransform =
translator.fromConfigRow(translatedConfigRow, PipelineOptionsFactory.create());

assertEquals(configurationRow(), translatedTransform.getConfigurationRow());
}

@Test
public void testTransformProtoTranslation() throws InvalidProtocolBufferException, IOException {
SchemaTransformProvider provider = provider();
Row configurationRow = configurationRow();

// Infer if it's a read or write SchemaTransform and build pipeline accordingly
Pipeline p = Pipeline.create();
SchemaTransform schemaTransform = provider.from(configurationRow);
input(p).apply(schemaTransform);

// Then translate the pipeline to a proto and extract the SchemaTransform's proto
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
List<RunnerApi.PTransform> schemaTransformProto =
pipelineProto.getComponents().getTransformsMap().values().stream()
.filter(
tr -> {
RunnerApi.FunctionSpec spec = tr.getSpec();
try {
return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))
&& ExternalTransforms.SchemaTransformPayload.parseFrom(spec.getPayload())
.getIdentifier()
.equals(provider.identifier());
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList());
assertEquals(1, schemaTransformProto.size());
RunnerApi.FunctionSpec spec = schemaTransformProto.get(0).getSpec();

// Check that the proto contains correct values
ExternalTransforms.SchemaTransformPayload payload =
ExternalTransforms.SchemaTransformPayload.parseFrom(spec.getPayload());
Schema translatedSchema = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());
assertEquals(provider.configurationSchema(), translatedSchema);
Row translatedConfigRow =
RowCoder.of(translatedSchema).decode(payload.getConfigurationRow().newInput());

assertEquals(configurationRow, translatedConfigRow);

// Use the information in the proto to recreate the transform
SchemaTransform translatedTransform =
new SchemaTransformPayloadTranslator(provider)
.fromConfigRow(translatedConfigRow, PipelineOptionsFactory.create());

assertEquals(configurationRow, translatedTransform.getConfigurationRow());
}
}
Loading
Loading