Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default translation for SchemaTransforms #31558

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@
package org.apache.beam.sdk.schemas.transforms;

import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
* An abstraction representing schema capable and aware transforms. The interface is intended to be
Expand All @@ -33,5 +38,60 @@
* compatibility guarantees and it should not be implemented outside of the Beam repository.
*/
@Internal
public abstract class SchemaTransform
extends PTransform<PCollectionRowTuple, PCollectionRowTuple> {}
public abstract class SchemaTransform extends PTransform<PCollectionRowTuple, PCollectionRowTuple> {
private @Nullable Row configurationRow;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like putting these (optional) private variables as part of the base class violates separation of concerns.

private @Nullable String identifier;
private boolean registered = false;
Copy link
Contributor Author

@ahmedabu98 ahmedabu98 Jun 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use "registered" throughout, but I'm not sure if it's the best terminology here. Open to other suggestions!


/**
* Stores the transform's identifier and configuration {@link Row} used to build this instance.
* Doing so allows this transform to be translated from/to proto using {@link
* org.apache.beam.sdk.util.construction.PTransformTranslation.SchemaTransformTranslator}.
*/
public SchemaTransform register(Row configurationRow, String identifier) {
this.configurationRow = configurationRow;
this.identifier = identifier;
registered = true;

return this;
}

/**
* Like {@link #register(Row, String)}, but with a configuration POJO. This is a convenience
* method for {@link TypedSchemaTransformProvider} implementations.
*/
public <ConfigT> SchemaTransform register(
ConfigT configuration, Class<ConfigT> configClass, String identifier) {
SchemaRegistry registry = SchemaRegistry.createDefault();
try {
// Get initial row with values
// then sort lexicographically and convert to snake_case
Row configRow =
registry.getToRowFunction(configClass).apply(configuration).sorted().toSnakeCase();
return register(configRow, identifier);
} catch (NoSuchSchemaException e) {
throw new RuntimeException(
String.format(
"Unable to find schema for this SchemaTransform's config type: %s", configClass),
e);
}
}

public Row getConfigurationRow() {
return Preconditions.checkNotNull(
configurationRow,
"Could not fetch Row configuration for %s. Please store it using .register().",
getClass());
}

public String getIdentifier() {
return Preconditions.checkNotNull(
identifier,
"Could not fetch identifier for %s. Please store it using .register().",
getClass());
}

public boolean isRegistered() {
return registered;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.ServiceLoader;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
import org.apache.beam.model.pipeline.v1.SchemaApi;
import org.apache.beam.sdk.coders.RowCoder;
Expand All @@ -36,13 +39,23 @@
import org.checkerframework.checker.nullness.qual.Nullable;

/**
* A {@link TransformPayloadTranslator} implementation that translates between a Java {@link
* SchemaTransform} and a protobuf payload for that transform.
* A default {@link TransformPayloadTranslator} implementation for {@link SchemaTransform}s.
*
* <p>Note: This is only eligible for registered SchemaTransform instances (using {@link
* SchemaTransform#register}).
*/
public class SchemaTransformTranslation {
public abstract static class SchemaTransformPayloadTranslator<T extends SchemaTransform>
implements TransformPayloadTranslator<T> {
public abstract SchemaTransformProvider provider();
public static class SchemaTransformPayloadTranslator
implements TransformPayloadTranslator<SchemaTransform> {
private final SchemaTransformProvider provider;

public String identifier() {
return provider.identifier();
}

public SchemaTransformPayloadTranslator(SchemaTransformProvider provider) {
this.provider = provider;
}

@Override
public String getUrn() {
Expand All @@ -52,18 +65,19 @@ public String getUrn() {
@Override
@SuppressWarnings("argument")
public @Nullable FunctionSpec translate(
AppliedPTransform<?, ?, T> application, SdkComponents components) throws IOException {
AppliedPTransform<?, ?, SchemaTransform> application, SdkComponents components)
throws IOException {
SchemaApi.Schema expansionSchema =
SchemaTranslation.schemaToProto(provider().configurationSchema(), true);
SchemaTranslation.schemaToProto(provider.configurationSchema(), true);
Row configRow = toConfigRow(application.getTransform());
ByteArrayOutputStream os = new ByteArrayOutputStream();
RowCoder.of(provider().configurationSchema()).encode(configRow, os);
RowCoder.of(provider.configurationSchema()).encode(configRow, os);

return FunctionSpec.newBuilder()
.setUrn(getUrn())
.setPayload(
ExternalTransforms.SchemaTransformPayload.newBuilder()
.setIdentifier(provider().identifier())
.setIdentifier(provider.identifier())
.setConfigurationSchema(expansionSchema)
.setConfigurationRow(ByteString.copyFrom(os.toByteArray()))
.build()
Expand All @@ -72,8 +86,21 @@ public String getUrn() {
}

@Override
public T fromConfigRow(Row configRow, PipelineOptions options) {
return (T) provider().from(configRow);
public Row toConfigRow(SchemaTransform transform) {
return transform.getConfigurationRow();
}

@Override
public SchemaTransform fromConfigRow(Row configRow, PipelineOptions options) {
return provider.from(configRow);
}
}

public static Map<String, SchemaTransformPayloadTranslator> getDefaultTranslators() {
Map<String, SchemaTransformPayloadTranslator> translators = new HashMap<>();
for (SchemaTransformProvider provider : ServiceLoader.load(SchemaTransformProvider.class)) {
translators.put(provider.identifier(), new SchemaTransformPayloadTranslator(provider));
}
return translators;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.CoderUtils;
Expand All @@ -59,6 +61,7 @@
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSortedSet;
Expand Down Expand Up @@ -253,6 +256,7 @@ private static Collection<TransformTranslator<?>> loadKnownTranslators() {
.add(new KnownTransformPayloadTranslator())
.add(ParDoTranslator.create())
.add(ExternalTranslator.create())
.add(new SchemaTransformTranslator())
.build();
}

Expand Down Expand Up @@ -581,6 +585,84 @@ public RunnerApi.PTransform translate(
}
}

/**
* Translates {@link SchemaTransform}s by populating the {@link FunctionSpec} with a {@link
* ExternalTransforms.SchemaTransformPayload} containing the transform's configuration {@link
* Schema} and {@link Row}.
*
* <p>This can be used as a default translator for SchemaTransforms. If further customization is
* needed, you can develop a {@link TransformPayloadTranslator} implementation and include it in a
* {@link TransformPayloadTranslatorRegistrar}, which will be picked up by {@link
* KnownTransformPayloadTranslator}.
*
* <p>Note: This default translator is only eligible for registered SchemaTransform instances
* (using {@link SchemaTransform#register}).
*/
private static class SchemaTransformTranslator implements TransformTranslator<SchemaTransform> {
@Override
public @Nullable String getUrn(SchemaTransform transform) {
return transform.getIdentifier();
}

@Override
public boolean canTranslate(PTransform transform) {
// Can translate only if the SchemaTransform's configuration Row and identifier are accessible
if (transform instanceof SchemaTransform) {
return (((SchemaTransform) transform).isRegistered());
}
return false;
}

@Override
public RunnerApi.PTransform translate(
AppliedPTransform<?, ?, ?> appliedPTransform,
List<AppliedPTransform<?, ?, ?>> subtransforms,
SdkComponents components)
throws IOException {
RunnerApi.PTransform.Builder transformBuilder =
translateAppliedPTransform(appliedPTransform, subtransforms, components);

String identifier = ((SchemaTransform) appliedPTransform.getTransform()).getIdentifier();
TransformPayloadTranslator payloadTranslator =
SchemaTransformTranslation.getDefaultTranslators().get(identifier);

FunctionSpec spec = payloadTranslator.translate(appliedPTransform, components);
Row configRow = payloadTranslator.toConfigRow(appliedPTransform.getTransform());

if (spec != null) {
transformBuilder.setSpec(spec);
transformBuilder.putAnnotations(
BeamUrns.getConstant(Annotations.Enum.SCHEMATRANSFORM_URN_KEY),
ByteString.copyFromUtf8(identifier));
if (identifier.equals(MANAGED_TRANSFORM_URN)) {
String underlyingIdentifier =
Preconditions.checkNotNull(
configRow.getString("transform_identifier"),
"Encountered a Managed Transform that has an empty \"transform_identifier\": %n%s",
configRow);
transformBuilder.putAnnotations(
BeamUrns.getConstant(Annotations.Enum.MANAGED_UNDERLYING_TRANSFORM_URN_KEY),
ByteString.copyFromUtf8(underlyingIdentifier));
}
}
transformBuilder.putAnnotations(
BeamUrns.getConstant(Annotations.Enum.CONFIG_ROW_KEY),
ByteString.copyFrom(
CoderUtils.encodeToByteArray(RowCoder.of(configRow.getSchema()), configRow)));
transformBuilder.putAnnotations(
BeamUrns.getConstant(Annotations.Enum.CONFIG_ROW_SCHEMA_KEY),
ByteString.copyFrom(
SchemaTranslation.schemaToProto(configRow.getSchema(), true).toByteArray()));

for (Entry<String, byte[]> annotation :
appliedPTransform.getTransform().getAnnotations().entrySet()) {
transformBuilder.putAnnotations(
annotation.getKey(), ByteString.copyFrom(annotation.getValue()));
}

return transformBuilder.build();
}
}
/**
* Translates an {@link AppliedPTransform} by:
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,18 @@
import org.apache.beam.sdk.testing.UsesSchema;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Test for {@link Select}. */
@RunWith(JUnit4.class)
@Category(UsesSchema.class)
public class TypedSchemaTransformProviderTest {
@Rule public ExpectedException thrown = ExpectedException.none();

/** flat schema to select from. */
@DefaultSchema(AutoValueSchema.class)
Expand Down Expand Up @@ -105,7 +108,7 @@ public String identifier() {

@Override
public SchemaTransform from(Configuration config) {
return new FakeSchemaTransform(config);
return new FakeSchemaTransform(config).register(config, Configuration.class, identifier());
}
}

Expand All @@ -123,6 +126,38 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
}
}

@Test
public void testInferConfigurationClass() {
assertEquals(Configuration.class, new FakeTypedSchemaIOProvider().configurationClass());
assertEquals(Configuration.class, new FakeMinimalTypedProvider().configurationClass());
}

@Test
public void testGetConfigurationRow() {
FakeMinimalTypedProvider minimalProvider = new FakeMinimalTypedProvider();
Configuration inputConfig = Configuration.create("field1", 13);
SchemaTransform transform = minimalProvider.from(inputConfig);

Row expectedConfig =
Row.withSchema(minimalProvider.configurationSchema())
.withFieldValue("string_field", "field1")
.withFieldValue("integer_field", 13)
.build();

assertEquals(expectedConfig, transform.getConfigurationRow());
assertEquals(minimalProvider.identifier(), transform.getIdentifier());

// FakeTypedSchemaIOProvider doesn't register its schematransform.
// Check that a helpful error message is returned.
FakeTypedSchemaIOProvider fakeProvider = new FakeTypedSchemaIOProvider();
SchemaTransform unregisteredTransform = fakeProvider.from(inputConfig);
thrown.expect(NullPointerException.class);
thrown.expectMessage("Could not fetch Row configuration");
thrown.expectMessage("FakeSchemaTransform");
thrown.expectMessage("Please store it using .register()");
unregisteredTransform.getConfigurationRow();
}

@Test
public void testFrom() {
SchemaTransformProvider provider = new FakeTypedSchemaIOProvider();
Expand Down
Loading
Loading