From f3aa0aec446a781b42c8ce3973e30393bfdba333 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Tue, 12 Sep 2023 21:27:46 -0700 Subject: [PATCH 1/8] Upgrade Java transforms without upgrading the pipelines --- runners/core-construction-java/build.gradle | 1 + .../core/construction/CombineTranslation.java | 6 +- .../CreatePCollectionViewTranslation.java | 2 +- .../core/construction/CreateTranslation.java | 128 ++++++ .../ExternalTranslationOptions.java | 43 ++ .../ExternalTranslationOptionsRegistrar.java | 36 ++ .../core/construction/FlattenTranslator.java | 2 +- .../construction/GroupByKeyTranslation.java | 2 +- .../GroupIntoBatchesTranslation.java | 4 +- .../core/construction/ImpulseTranslation.java | 2 +- .../construction/PTransformTranslation.java | 52 ++- .../construction/PipelineTranslation.java | 378 ++++++++++++++++++ .../core/construction/ReadTranslation.java | 4 +- .../construction/ReshuffleTranslation.java | 2 +- .../core/construction/SplittableParDo.java | 2 +- .../construction/TestStreamTranslation.java | 2 +- .../construction/WindowIntoTranslation.java | 5 +- .../construction/WriteFilesTranslation.java | 2 +- .../beam/runners/dataflow/DataflowRunner.java | 5 + .../dataflow/PrimitiveParDoSingleFactory.java | 2 +- .../runners/dataflow/DataflowRunnerTest.java | 3 +- .../apache/beam/sdk/transforms/Create.java | 4 + .../expansion/service/ExpansionService.java | 65 +++ .../gcp/pubsub/PubSubPayloadTranslation.java | 9 +- 24 files changed, 731 insertions(+), 30 deletions(-) create mode 100644 runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreateTranslation.java create mode 100644 runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslationOptions.java create mode 100644 runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslationOptionsRegistrar.java diff --git a/runners/core-construction-java/build.gradle b/runners/core-construction-java/build.gradle index feac7c37c8e0..f593865b3fe9 100644 --- a/runners/core-construction-java/build.gradle +++ b/runners/core-construction-java/build.gradle @@ -55,6 +55,7 @@ dependencies { implementation project(path: ":sdks:java:core", configuration: "shadow") implementation project(path: ":sdks:java:extensions:avro") implementation project(path: ":sdks:java:fn-execution") + implementation project(path: ":sdks:java:transform-service:launcher") implementation library.java.vendored_grpc_1_54_0 implementation library.java.vendored_guava_32_1_2_jre implementation library.java.classgraph diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java index 3f902acf250c..fbe4876b414f 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java @@ -61,7 +61,7 @@ public static class CombinePerKeyPayloadTranslator private CombinePerKeyPayloadTranslator() {} @Override - public String getUrn(Combine.PerKey transform) { + public String getUrn() { return COMBINE_PER_KEY_TRANSFORM_URN; } @@ -108,7 +108,7 @@ public static class CombineGloballyPayloadTranslator private CombineGloballyPayloadTranslator() {} @Override - public String getUrn(Combine.Globally transform) { + public String getUrn() { return COMBINE_GLOBALLY_TRANSFORM_URN; } @@ -165,7 +165,7 @@ public static class CombineGroupedValuesPayloadTranslator private CombineGroupedValuesPayloadTranslator() {} @Override - public String getUrn(Combine.GroupedValues transform) { + public String getUrn() { return COMBINE_GROUPED_VALUES_TRANSFORM_URN; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslation.java index a679737fd616..71038564ec4c 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslation.java @@ -90,7 +90,7 @@ public static PCollectionView getView( static class CreatePCollectionViewTranslator implements TransformPayloadTranslator> { @Override - public String getUrn(View.CreatePCollectionView transform) { + public String getUrn() { return PTransformTranslation.CREATE_VIEW_TRANSFORM_URN; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreateTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreateTranslation.java new file mode 100644 index 000000000000..076a5e437e5a --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreateTranslation.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.construction; + +import com.google.auto.service.AutoService; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.Create.Values; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.Row; +import org.checkerframework.checker.nullness.qual.Nullable; + +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + "nullness" +}) +public class CreateTranslation implements TransformPayloadTranslator> { + + Schema createConfigSchema = + Schema.builder() + .addArrayField("values", FieldType.BYTES) + .addByteArrayField("serialized_coder") + .build(); + + @Override + public String getUrn() { + return PTransformTranslation.CREATE_TRANSFORM_URN; + } + + @Override + public @Nullable FunctionSpec translate( + AppliedPTransform> application, SdkComponents components) throws IOException { + // Currently just returns an empty payload. + // We can implement an actual payload of runners start using this transform. + return FunctionSpec.newBuilder().setUrn(getUrn(application.getTransform())).build(); + } + + private byte[] toByteArray(Object object) { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(bos)) { + out.writeObject(object); + return bos.toByteArray(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private Object fromByteArray(byte[] bytes) { + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + ObjectInputStream in = new ObjectInputStream(bis)) { + return in.readObject(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public @Nullable Row toConfigRow(Values transform) { + List encodedElements = new ArrayList<>(); + transform + .getElements() + .forEach( + object -> { + encodedElements.add(toByteArray(object)); + }); + + byte[] serializedCoder = + transform.getCoder() != null ? toByteArray(transform.getCoder()) : new byte[] {}; + return Row.withSchema(createConfigSchema) + .withFieldValue("values", encodedElements) + .withFieldValue("serialized_coder", serializedCoder) + .build(); + } + + @Override + public Create.@Nullable Values fromConfigRow(Row configRow) { + Values transform = + Create.of( + configRow.getArray("values").stream() + .map(bytesValue -> fromByteArray((byte[]) bytesValue)) + .collect(Collectors.toList())); + byte[] serializedCoder = configRow.getBytes("serialized_coder"); + if (serializedCoder.length > 0) { + Coder coder = (Coder) fromByteArray(serializedCoder); + transform = transform.withCoder(coder); + } + return transform; + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class Registrar implements TransformPayloadTranslatorRegistrar { + @Override + public Map, ? extends TransformPayloadTranslator> + getTransformPayloadTranslators() { + return Collections.singletonMap(Create.Values.class, new CreateTranslation()); + } + } +} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslationOptions.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslationOptions.java new file mode 100644 index 000000000000..4b3ef24ca1d2 --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslationOptions.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.construction; + +import java.util.List; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.resourcehints.ResourceHintsOptions.EmptyListDefault; + +public interface ExternalTranslationOptions extends PipelineOptions { + + @Description("Set of URNs of transforms to be overriden using the transform service.") + @Default.InstanceFactory(EmptyListDefault.class) + List getTransformsToOverride(); + + void setTransformsToOverride(List transformsToOverride); + + @Description("Address of an already available transform service.") + String getTransformServiceAddress(); + + void setTransformServiceAddress(String transformServiceAddress); + + @Description("An available Beam version which will be used to start a transform service.") + String getTransformServiceBeamVersion(); + + void setTransformServiceBeamVersion(String transformServiceBeamVersion); +} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslationOptionsRegistrar.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslationOptionsRegistrar.java new file mode 100644 index 000000000000..6296f4c83775 --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslationOptionsRegistrar.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.construction; + +import com.google.auto.service.AutoService; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsRegistrar; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; + +/** A registrar for ExternalTranslationOptions. */ +@AutoService(PipelineOptionsRegistrar.class) +@Internal +public class ExternalTranslationOptionsRegistrar implements PipelineOptionsRegistrar { + @Override + public Iterable> getPipelineOptions() { + return ImmutableList.>builder() + .add(ExternalTranslationOptions.class) + .build(); + } +} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/FlattenTranslator.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/FlattenTranslator.java index 37c09663c5a2..201a65e6233c 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/FlattenTranslator.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/FlattenTranslator.java @@ -43,7 +43,7 @@ public static TransformPayloadTranslator create() { private FlattenTranslator() {} @Override - public String getUrn(Flatten.PCollections transform) { + public String getUrn() { return PTransformTranslation.FLATTEN_TRANSFORM_URN; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/GroupByKeyTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/GroupByKeyTranslation.java index e6bbbf0767a5..183fa7ffcdc9 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/GroupByKeyTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/GroupByKeyTranslation.java @@ -38,7 +38,7 @@ public class GroupByKeyTranslation { static class GroupByKeyTranslator implements TransformPayloadTranslator> { @Override - public String getUrn(GroupByKey transform) { + public String getUrn() { return PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/GroupIntoBatchesTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/GroupIntoBatchesTranslation.java index c91e9cedb9ac..7c81afd8ae07 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/GroupIntoBatchesTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/GroupIntoBatchesTranslation.java @@ -39,7 +39,7 @@ public class GroupIntoBatchesTranslation { static class GroupIntoBatchesTranslator implements TransformPayloadTranslator> { @Override - public String getUrn(GroupIntoBatches transform) { + public String getUrn() { return PTransformTranslation.GROUP_INTO_BATCHES_URN; } @@ -61,7 +61,7 @@ public RunnerApi.FunctionSpec translate( static class ShardedGroupIntoBatchesTranslator implements TransformPayloadTranslator.WithShardedKey> { @Override - public String getUrn(GroupIntoBatches.WithShardedKey transform) { + public String getUrn() { return PTransformTranslation.GROUP_INTO_BATCHES_WITH_SHARDED_KEY_URN; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ImpulseTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ImpulseTranslation.java index 3de0ce9de8ac..25f0cd7749b5 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ImpulseTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ImpulseTranslation.java @@ -37,7 +37,7 @@ public class ImpulseTranslation { private static class ImpulseTranslator implements TransformPayloadTranslator { @Override - public String getUrn(Impulse transform) { + public String getUrn() { return PTransformTranslation.IMPULSE_TRANSFORM_URN; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java index 2acd77885fcc..02d0ceb37e3b 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java @@ -45,6 +45,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; @@ -69,6 +70,7 @@ public class PTransformTranslation { // and we validate that the value matches the actual URN in the static block below. // Primitives + public static final String CREATE_TRANSFORM_URN = "beam:transform:create:v1"; public static final String PAR_DO_TRANSFORM_URN = "beam:transform:pardo:v1"; public static final String FLATTEN_TRANSFORM_URN = "beam:transform:flatten:v1"; public static final String GROUP_BY_KEY_TRANSFORM_URN = "beam:transform:group_by_key:v1"; @@ -386,9 +388,9 @@ public RunnerApi.PTransform translate( * Translates a set of registered transforms whose content only differs based by differences in * their {@link FunctionSpec}s and URNs. */ - private static class KnownTransformPayloadTranslator> + public static class KnownTransformPayloadTranslator> implements TransformTranslator { - private static final Map, TransformPayloadTranslator> + public static final Map, TransformPayloadTranslator> KNOWN_PAYLOAD_TRANSLATORS = loadTransformPayloadTranslators(); private static Map, TransformPayloadTranslator> @@ -508,14 +510,56 @@ static RunnerApi.PTransform.Builder translateAppliedPTransform( * *

When going to a protocol buffer message, the translator produces a payload corresponding to * the Java representation while registering components that payload references. + * + *

Also, provides methods for generating a Row-based constructor config for the transform that + * can be later used to re-construct the transform. */ public interface TransformPayloadTranslator> { - String getUrn(T transform); + /** + * Provides a unique URN for transforms represented by this {@code TransformPayloadTranslator}. + */ + String getUrn(); + + /** + * Same as {@link #getUrn()} but the returned URN may depend on the transform provided. + * + *

Only override this if the same {@code TransformPayloadTranslator} used for multiple + * transforms. Otherwise, use {@link #getUrn()}. + */ + default String getUrn(T transform) { + return getUrn(); + } + + /** + * Translates the given transform represented by the provided {@code AppliedPTransform} to a + * {@code FunctionSpec} with a URN and a payload. + */ @Nullable FunctionSpec translate(AppliedPTransform application, SdkComponents components) throws IOException; + /** + * Generates a Row-based construction configuration for the provided transform. + * + * @param transform a transform represented by the current {@code TransformPayloadTranslator}. + * @return + */ + default Row toConfigRow(T transform) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** + * Construts a transform from a provided Row-based construction configuration. + * + * @param configRow a construction configuration similar to what would be generated by the + * {@link #toConfigRow(PTransform)} method. + * @return a transform represented by the current {@code TransformPayloadTranslator}. + */ + default T fromConfigRow(Row configRow) { + throw new UnsupportedOperationException("Not implemented"); + } + /** * A {@link TransformPayloadTranslator} for transforms that contain no references to components, * so they do not need a specialized rehydration. @@ -526,7 +570,7 @@ abstract class NotSerializable> public static NotSerializable forUrn(final String urn) { return new NotSerializable>() { @Override - public String getUrn(PTransform transform) { + public String getUrn() { return urn; } }; diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java index 53553e7062b3..4d4c625154c8 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java @@ -18,20 +18,39 @@ package org.apache.beam.runners.core.construction; import java.io.IOException; +import java.net.ServerSocket; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Set; +import java.util.UUID; +import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; +import org.apache.beam.runners.core.construction.External.ExpandableTransform; +import org.apache.beam.runners.core.construction.PTransformTranslation.KnownTransformPayloadTranslator; +import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; import org.apache.beam.runners.core.construction.graph.PipelineValidator; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.TransformHierarchy.Node; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transformservice.launcher.TransformServiceLauncher; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ListMultimap; @@ -102,11 +121,370 @@ public void visitPrimitiveTransform(Node node) { // TODO(JIRA-5649): Don't even emit these transforms in the generated protos. res = elideDeprecatedViews(res); } + + ExternalTranslationOptions externalTranslationOptions = + pipeline.getOptions().as(ExternalTranslationOptions.class); + List urnsToOverride = externalTranslationOptions.getTransformsToOverride(); + if (urnsToOverride.size() > 0) { + // We use PTransformPayloadTranslators and the Transform Service to re-generate the pipeline + // proto components for the updated transform and update the pipeline proto. + Map transforms = res.getComponents().getTransformsMap(); + List alreadyCheckedURns = new ArrayList<>(); + for (Entry entry : transforms.entrySet()) { + String urn = entry.getValue().getSpec().getUrn(); + if (!alreadyCheckedURns.contains(urn) && urnsToOverride.contains(urn)) { + alreadyCheckedURns.add(urn); + // All transforms in the pipeline with the given urns have to be overridden. + List< + AppliedPTransform< + PInput, + POutput, + org.apache.beam.sdk.transforms.PTransform>> + appliedPTransforms = + findAppliedPTransforms( + urn, pipeline, KnownTransformPayloadTranslator.KNOWN_PAYLOAD_TRANSLATORS); + for (AppliedPTransform< + PInput, + POutput, + org.apache.beam.sdk.transforms.PTransform> + appliedPTransform : appliedPTransforms) { + TransformPayloadTranslator< + org.apache.beam.sdk.transforms.PTransform> + payloadTranslator = + KnownTransformPayloadTranslator.KNOWN_PAYLOAD_TRANSLATORS.get( + appliedPTransform.getTransform().getClass()); + try { + // Override the transform using the transform service. + res = + updateTransformViaTransformService( + urn, appliedPTransform, payloadTranslator, pipeline, res); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + } + } + // Validate that translation didn't produce an invalid pipeline. PipelineValidator.validate(res); return res; } + private static int findAvailablePort() throws IOException { + ServerSocket s = new ServerSocket(0); + try { + return s.getLocalPort(); + } finally { + s.close(); + try { + // Some systems don't free the port for future use immediately. + Thread.sleep(100); + } catch (InterruptedException exn) { + // ignore + } + } + } + + // Override the given transform to the version available in a new transform service. + private static < + InputT extends PInput, + OutputT extends POutput, + TransformT extends org.apache.beam.sdk.transforms.PTransform> + RunnerApi.Pipeline updateTransformViaTransformService( + String urn, + AppliedPTransform< + PInput, + POutput, + org.apache.beam.sdk.transforms.PTransform> + appliedPTransform, + TransformPayloadTranslator< + org.apache.beam.sdk.transforms.PTransform> + originalPayloadTranslator, + Pipeline pipeline, + RunnerApi.Pipeline runnerAPIpipeline) + throws IOException { + ExternalTranslationOptions externalTranslationOptions = + pipeline.getOptions().as(ExternalTranslationOptions.class); + + // Config row to re-construct the transform within the transform service. + Row configRow = originalPayloadTranslator.toConfigRow(appliedPTransform.getTransform()); + ByteStringOutputStream outputStream = new ByteStringOutputStream(); + try { + RowCoder.of(configRow.getSchema()).encode(configRow, outputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + + // Java expansion serivice able to identify and expand transforms that includes the construction + // config provided here. + ExternalTransforms.ExternalConfigurationPayload payload = + ExternalTransforms.ExternalConfigurationPayload.newBuilder() + .setSchema(SchemaTranslation.schemaToProto(configRow.getSchema(), true)) + .setPayload(outputStream.toByteString()) + .build(); + + String serviceAddress = null; + TransformServiceLauncher service = null; + try { + if (externalTranslationOptions.getTransformServiceAddress() != null) { + serviceAddress = externalTranslationOptions.getTransformServiceAddress(); + } else if (externalTranslationOptions.getTransformServiceBeamVersion() != null) { + String projectName = UUID.randomUUID().toString(); + service = TransformServiceLauncher.forProject(projectName, findAvailablePort()); + service.setBeamVersion(externalTranslationOptions.getTransformServiceBeamVersion()); + + // Starting the transform service. + service.start(); + // Waiting the service to be ready. + service.waitTillUp(15000); + } else { + throw new IllegalArgumentException( + "Either option TransformServiceAddress or option TransformServiceBeamVersion should be provided to override a transform using the transform service"); + } + + if (serviceAddress == null) { + throw new IllegalArgumentException( + "Cannot override the transform " + + urn + + " since a valid transform service address could not be determined"); + } + + // Creating an ExternalTransform and expanding it using the transform service. + // Input will be the same input provided to the transform bing overridden. + ExpandableTransform externalTransform = + (ExpandableTransform) + External.of(urn, payload.toByteArray(), serviceAddress); + + PCollectionTuple input = PCollectionTuple.empty(pipeline); + for (TupleTag tag : (Set>) appliedPTransform.getInputs().keySet()) { + PCollection pc = appliedPTransform.getInputs().get(tag); + if (pc == null) { + throw new IllegalArgumentException( + "Input of transform " + appliedPTransform + " with tag " + tag + " was null."); + } + input = input.and(tag, (PCollection) pc); + } + POutput output = externalTransform.expand((InputT) input); + + // Outputs of the transform being overridden. + Map, PCollection> originalOutputs = appliedPTransform.getOutputs(); + + // After expansion some transforms might still refer to the output of the already overridden + // transform as their input. + // Such inputs have to be overridden to use the output of the new upgraded transform. + Map inputReplacements = new HashMap<>(); + + // Will contain the outputs of the upgraded transform. + Map, PCollection> newOutputs = new HashMap<>(); + + if (output instanceof PCollectionTuple) { + newOutputs.putAll(((PCollectionTuple) output).getAll()); + for (Map.Entry, PCollection> entry : newOutputs.entrySet()) { + if (entry == null) { + throw new IllegalArgumentException( + "Found unexpected null entry when iterating the outputs of expanded " + + "ExpandableTransform " + + externalTransform); + } + if (!appliedPTransform.getOutputs().containsKey(entry.getKey())) { + throw new RuntimeException( + "Could not find the tag " + entry.getKey() + " in the original set of outputs"); + } + PCollection originalOutputPc = originalOutputs.get(entry.getKey()); + if (originalOutputPc == null) { + throw new IllegalArgumentException( + "Original output of transform " + + appliedPTransform + + " with tag " + + entry.getKey() + + " was null"); + } + inputReplacements.put(originalOutputPc.getName(), entry.getValue().getName()); + } + } else if (output instanceof PCollection) { + newOutputs.put(new TupleTag<>("temp_main_tag"), (PCollection) output); + inputReplacements.put( + originalOutputs.get(originalOutputs.keySet().iterator().next()).getName(), + ((PCollection) output).getName()); + } else { + throw new RuntimeException("Unexpected output type"); + } + + // We create a new AppliedPTransform to represent the upgraded transform and register it in an + // SdkComponents object. + AppliedPTransform updatedAppliedPTransform = + AppliedPTransform.of( + appliedPTransform.getFullName() + "_external", + appliedPTransform.getInputs(), + newOutputs, + externalTransform, + externalTransform.getResourceHints(), + appliedPTransform.getPipeline()); + SdkComponents updatedComponents = + SdkComponents.create( + runnerAPIpipeline.getComponents(), runnerAPIpipeline.getRequirementsList()); + String updatedTransformId = + updatedComponents.registerPTransform(updatedAppliedPTransform, Collections.emptyList()); + RunnerApi.Components updatedRunnerApiComponents = updatedComponents.toComponents(); + + // Recording input updates to the transforms to refer to the upgraded transform instead of the + // old one. + // Also recording the newly generated id of the old (overridden) transform in the + // updatedRunnerApiComponents. + Map> transformInputUpdates = new HashMap<>(); + List oldTransformIds = new ArrayList<>(); + updatedRunnerApiComponents + .getTransformsMap() + .forEach( + (transformId, transform) -> { + // Mapping from existing key to new value. + Map updatedInputMap = new HashMap<>(); + for (Map.Entry entry : transform.getInputsMap().entrySet()) { + if (inputReplacements.containsKey(entry.getValue())) { + updatedInputMap.put(entry.getKey(), inputReplacements.get(entry.getValue())); + } + } + for (Map.Entry entry : transform.getOutputsMap().entrySet()) { + if (inputReplacements.containsKey(entry.getValue()) + && urn.equals(transform.getSpec().getUrn())) { + oldTransformIds.add(transformId); + } + } + if (updatedInputMap.size() > 0) { + transformInputUpdates.put(transformId, updatedInputMap); + } + }); + // There should be only one recorded old (upgraded) transform. + if (oldTransformIds.size() != 1) { + throw new IOException( + "Expected exactly one transform to be updated by " + + oldTransformIds.size() + + " were updated."); + } + String oldTransformId = oldTransformIds.get(0); + + // Updated list of root transforms (in case a root was upgraded). + List updaterRootTransformIds = new ArrayList<>(); + updaterRootTransformIds.addAll(runnerAPIpipeline.getRootTransformIdsList()); + if (updaterRootTransformIds.contains(oldTransformId)) { + updaterRootTransformIds.remove(oldTransformId); + updaterRootTransformIds.add(updatedTransformId); + } + + // Generating the updated list of transforms. + // Also updates the input references to refer to the upgraded transform. + // Also updates the sub-transform reference to refer to the new transform. + Map updatedTransforms = new HashMap<>(); + updatedRunnerApiComponents + .getTransformsMap() + .forEach( + (transformId, transform) -> { + if (transformId.equals(oldTransformId)) { + // Do not include the old (upgraded) transform. + return; + } + PTransform.Builder transformBuilder = transform.toBuilder(); + if (transformInputUpdates.containsKey(transformId)) { + Map inputUpdates = transformInputUpdates.get(transformId); + transformBuilder + .getInputsMap() + .forEach( + (key, value) -> { + if (inputUpdates.containsKey(key)) { + transformBuilder.putInputs(key, inputUpdates.get(key)); + } + }); + } + if (transform.getSubtransformsList().contains(oldTransformId)) { + List updatedSubTransformsList = new ArrayList<>(); + updatedSubTransformsList.addAll(transform.getSubtransformsList()); + updatedSubTransformsList.remove(oldTransformId); + updatedSubTransformsList.add(updatedTransformId); + transformBuilder.clearSubtransforms(); + transformBuilder.addAllSubtransforms(updatedSubTransformsList); + } + updatedTransforms.put(transformId, transformBuilder.build()); + }); + + // Generating components with the updated list of transforms without including the old + // (upgraded) transform. + updatedRunnerApiComponents = + updatedRunnerApiComponents + .toBuilder() + .putAllTransforms(updatedTransforms) + .removeTransforms(oldTransformId) + .build(); + + // Generating the updated pipeline. + RunnerApi.Pipeline updatedPipeline = + RunnerApi.Pipeline.newBuilder() + .setComponents(updatedRunnerApiComponents) + .addAllRequirements(updatedComponents.requirements()) + .addAllRootTransformIds(updaterRootTransformIds) + .build(); + + return updatedPipeline; + } catch (TimeoutException e) { + throw new IOException(e); + } finally { + if (service != null) { + service.shutdown(); + } + } + } + + // Find all AppliedPTransforms that represent transforms with the given URN. + @SuppressWarnings({ + // Pre-registered 'knownTranslators' are defined as raw types. + "rawtypes" + }) + private static < + InputT extends PInput, + OutputT extends POutput, + TransformT extends org.apache.beam.sdk.transforms.PTransform> + List> findAppliedPTransforms( + String urn, + Pipeline pipeline, + Map< + Class, + TransformPayloadTranslator> + knownTranslators) { + + List> appliedPTransforms = new ArrayList<>(); + pipeline.traverseTopologically( + new PipelineVisitor.Defaults() { + + void findMatchingAppliedPTransform(Node node) { + org.apache.beam.sdk.transforms.PTransform transform = node.getTransform(); + if (transform == null) { + return; + } + if (knownTranslators.containsKey(transform.getClass())) { + TransformPayloadTranslator translator = + knownTranslators.get(transform.getClass()); + if (translator.getUrn() != null && translator.getUrn().equals(urn)) { + appliedPTransforms.add( + (AppliedPTransform) + node.toAppliedPTransform(pipeline)); + } + } + } + + @Override + public void leaveCompositeTransform(Node node) { + findMatchingAppliedPTransform(node); + } + + @Override + public void visitPrimitiveTransform(Node node) { + findMatchingAppliedPTransform(node); + } + }); + + return appliedPTransforms; + } + private static RunnerApi.Pipeline elideDeprecatedViews(RunnerApi.Pipeline pipeline) { // Record data on CreateView operations. Set viewTransforms = new HashSet<>(); diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReadTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReadTranslation.java index 40a7205b5d57..f04d25509593 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReadTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReadTranslation.java @@ -154,7 +154,7 @@ public static TransformPayloadTranslator create() { private UnboundedReadPayloadTranslator() {} @Override - public String getUrn(SplittableParDo.PrimitiveUnboundedRead transform) { + public String getUrn() { return PTransformTranslation.READ_TRANSFORM_URN; } @@ -181,7 +181,7 @@ public static TransformPayloadTranslator create() { private BoundedReadPayloadTranslator() {} @Override - public String getUrn(SplittableParDo.PrimitiveBoundedRead transform) { + public String getUrn() { return PTransformTranslation.READ_TRANSFORM_URN; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReshuffleTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReshuffleTranslation.java index 98d39c8ff0ac..bd91673f3818 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReshuffleTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReshuffleTranslation.java @@ -38,7 +38,7 @@ public class ReshuffleTranslation { static class ReshuffleTranslator implements TransformPayloadTranslator> { @Override - public String getUrn(Reshuffle transform) { + public String getUrn() { return PTransformTranslation.RESHUFFLE_URN; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java index 42c9e523e965..5ea2c4968dd9 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java @@ -395,7 +395,7 @@ public static TransformPayloadTranslator create() { private ProcessKeyedElementsTranslator() {} @Override - public String getUrn(ProcessKeyedElements transform) { + public String getUrn() { return PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java index aa582bf14f3c..53bb324d03fa 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java @@ -168,7 +168,7 @@ static TestStream.Event eventFromProto( /** A translator registered to translate {@link TestStream} objects to protobuf representation. */ static class TestStreamTranslator implements TransformPayloadTranslator> { @Override - public String getUrn(TestStream transform) { + public String getUrn() { return TEST_STREAM_TRANSFORM_URN; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowIntoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowIntoTranslation.java index 1b3aa50c7b44..294d89308a31 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowIntoTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowIntoTranslation.java @@ -30,7 +30,6 @@ import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.windowing.Window; -import org.apache.beam.sdk.transforms.windowing.Window.Assign; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.InvalidProtocolBufferException; import org.checkerframework.checker.nullness.qual.Nullable; @@ -47,7 +46,7 @@ public class WindowIntoTranslation { static class WindowAssignTranslator implements TransformPayloadTranslator> { @Override - public String getUrn(Assign transform) { + public String getUrn() { return PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN; } @@ -116,7 +115,7 @@ public static TransformPayloadTranslator create() { private WindowIntoPayloadTranslator() {} @Override - public String getUrn(Window.Assign transform) { + public String getUrn() { return PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java index cce140536114..3a23ed073776 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java @@ -276,7 +276,7 @@ public boolean isRunnerDeterminedSharding() { static class WriteFilesTranslator implements TransformPayloadTranslator> { @Override - public String getUrn(WriteFiles transform) { + public String getUrn() { return WRITE_FILES_TRANSFORM_URN; } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 17aea34045ff..1686891594e4 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -2572,6 +2572,11 @@ public String getUrn(PTransform transform) { return "dataflow_stub:" + transform.getClass().getName(); } + @Override + public String getUrn() { + throw new UnsupportedOperationException("URN of DataflowPayloadTranslator depends on the transform. Please use 'getUrn(PTransform transform)' instead."); + } + @Override public RunnerApi.FunctionSpec translate( AppliedPTransform> application, SdkComponents components) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java index facbbb3f1b44..140858d88c04 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java @@ -157,7 +157,7 @@ public static PTransformTranslation.TransformPayloadTranslator create() { private PayloadTranslator() {} @Override - public String getUrn(ParDoSingle transform) { + public String getUrn() { return PAR_DO_TRANSFORM_URN; } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index 3c50ae6019f8..978e57313e4a 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -167,6 +167,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PValues; +import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.sdk.values.WindowingStrategy; @@ -1632,7 +1633,7 @@ public PCollection expand(PCollection input) { private static class TestTransformTranslator implements TransformPayloadTranslator { @Override - public String getUrn(TestTransform transform) { + public String getUrn() { return "test_transform"; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java index 079379953cd9..2c96f7c6b1f0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java @@ -333,6 +333,10 @@ public Iterable getElements() { return elems; } + public @Nullable Coder getCoder() { + return coder.isPresent() ? coder.get() : null; + } + @Override public PCollection expand(PBegin input) { Coder coder; diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java index 6b52f8d1245e..a1d4826e0791 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java @@ -18,10 +18,12 @@ package org.apache.beam.sdk.expansion.service; import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; +import static org.apache.beam.runners.core.construction.PTransformTranslation.READ_TRANSFORM_URN; import static org.apache.beam.runners.core.construction.resources.PipelineResources.detectClassPathResourcesToStage; import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; import com.google.auto.service.AutoService; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; @@ -45,10 +47,12 @@ import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.model.pipeline.v1.SchemaApi; import org.apache.beam.runners.core.construction.Environments; +import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; import org.apache.beam.runners.core.construction.PipelineTranslation; import org.apache.beam.runners.core.construction.RehydratedComponents; import org.apache.beam.runners.core.construction.SdkComponents; import org.apache.beam.runners.core.construction.SplittableParDo; +import org.apache.beam.runners.core.construction.TransformPayloadTranslatorRegistrar; import org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineResult; @@ -89,8 +93,10 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Converter; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; @@ -178,6 +184,65 @@ public List getDependencies( } } + List deprecatedTransformURNs = ImmutableList.of(READ_TRANSFORM_URN); + for (TransformPayloadTranslatorRegistrar registrar : + ServiceLoader.load(TransformPayloadTranslatorRegistrar.class)) { + for (Map.Entry, ? extends TransformPayloadTranslator> + entry : registrar.getTransformPayloadTranslators().entrySet()) { + @Initialized TransformPayloadTranslator translator = entry.getValue(); + if (translator == null) { + continue; + } + + String urn = null; + try { + urn = translator.getUrn(); + if (urn == null) { + LOG.debug( + "Could not load the TransformPayloadTranslator " + + translator + + " to the Expansion Service since it did not produce a unique URN."); + continue; + } + } catch (Exception e) { + LOG.info( + "Could not load the TransformPayloadTranslator " + + translator + + " to the Expansion Service."); + continue; + } + + if (deprecatedTransformURNs.contains(urn)) { + continue; + } + final String finalUrn = urn; + TransformProvider transformProvider = + spec -> { + try { + ExternalConfigurationPayload payload = + ExternalConfigurationPayload.parseFrom(spec.getPayload()); + Row configRow = + RowCoder.of(SchemaTranslation.schemaFromProto(payload.getSchema())) + .decode(new ByteArrayInputStream(payload.getPayload().toByteArray())); + PTransform transformFromRow = translator.fromConfigRow(configRow); + if (transformFromRow != null) { + return transformFromRow; + } else { + throw new RuntimeException( + String.format( + "A transform cannot be initiated using the provided config row %s and the TransformPayloadTranslator %s", + configRow, translator)); + } + } catch (Exception e) { + throw new RuntimeException( + String.format("Failed to build transform %s from spec %s", finalUrn, spec), + e); + } + }; + builder.put(finalUrn, transformProvider); + } + } + return builder.build(); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubSubPayloadTranslation.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubSubPayloadTranslation.java index c8214529d580..4722a3833fa9 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubSubPayloadTranslation.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubSubPayloadTranslation.java @@ -52,10 +52,7 @@ static class PubSubReadPayloadTranslator implements TransformPayloadTranslator> { @Override - public String getUrn(Read.Unbounded transform) { - if (!(transform.getSource() instanceof PubsubUnboundedSource.PubsubSource)) { - return null; - } + public String getUrn() { return PTransformTranslation.PUBSUB_READ; } @@ -106,7 +103,7 @@ public RunnerApi.FunctionSpec translate( static class PubSubWritePayloadTranslator implements TransformPayloadTranslator { @Override - public String getUrn(PubsubUnboundedSink.PubsubSink transform) { + public String getUrn() { return PTransformTranslation.PUBSUB_WRITE; } @@ -140,7 +137,7 @@ public RunnerApi.FunctionSpec translate( static class PubSubDynamicWritePayloadTranslator implements TransformPayloadTranslator { @Override - public String getUrn(PubsubUnboundedSink.PubsubDynamicSink transform) { + public String getUrn() { return PTransformTranslation.PUBSUB_WRITE_DYNAMIC; } From 3edbd62129d94ba415f768ce0adaf7c7c4770e37 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Wed, 20 Sep 2023 16:11:51 -0700 Subject: [PATCH 2/8] Addresses reviewer comments --- .../core/construction/CreateTranslation.java | 128 ------ .../runners/core/construction/External.java | 10 +- .../construction/PTransformTranslation.java | 63 ++- .../construction/PipelineTranslation.java | 377 +---------------- .../core/construction/TransformUpgrader.java | 380 ++++++++++++++++++ .../construction/TransformUpgraderTest.java | 355 ++++++++++++++++ .../apache/beam/sdk/transforms/Create.java | 4 - .../expansion/service/ExpansionService.java | 2 +- 8 files changed, 805 insertions(+), 514 deletions(-) delete mode 100644 runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreateTranslation.java create mode 100644 runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java create mode 100644 runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformUpgraderTest.java diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreateTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreateTranslation.java deleted file mode 100644 index 076a5e437e5a..000000000000 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreateTranslation.java +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.core.construction; - -import com.google.auto.service.AutoService; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; -import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.runners.AppliedPTransform; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.Create.Values; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.values.Row; -import org.checkerframework.checker.nullness.qual.Nullable; - -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" -}) -public class CreateTranslation implements TransformPayloadTranslator> { - - Schema createConfigSchema = - Schema.builder() - .addArrayField("values", FieldType.BYTES) - .addByteArrayField("serialized_coder") - .build(); - - @Override - public String getUrn() { - return PTransformTranslation.CREATE_TRANSFORM_URN; - } - - @Override - public @Nullable FunctionSpec translate( - AppliedPTransform> application, SdkComponents components) throws IOException { - // Currently just returns an empty payload. - // We can implement an actual payload of runners start using this transform. - return FunctionSpec.newBuilder().setUrn(getUrn(application.getTransform())).build(); - } - - private byte[] toByteArray(Object object) { - try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutputStream out = new ObjectOutputStream(bos)) { - out.writeObject(object); - return bos.toByteArray(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - private Object fromByteArray(byte[] bytes) { - try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); - ObjectInputStream in = new ObjectInputStream(bis)) { - return in.readObject(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public @Nullable Row toConfigRow(Values transform) { - List encodedElements = new ArrayList<>(); - transform - .getElements() - .forEach( - object -> { - encodedElements.add(toByteArray(object)); - }); - - byte[] serializedCoder = - transform.getCoder() != null ? toByteArray(transform.getCoder()) : new byte[] {}; - return Row.withSchema(createConfigSchema) - .withFieldValue("values", encodedElements) - .withFieldValue("serialized_coder", serializedCoder) - .build(); - } - - @Override - public Create.@Nullable Values fromConfigRow(Row configRow) { - Values transform = - Create.of( - configRow.getArray("values").stream() - .map(bytesValue -> fromByteArray((byte[]) bytesValue)) - .collect(Collectors.toList())); - byte[] serializedCoder = configRow.getBytes("serialized_coder"); - if (serializedCoder.length > 0) { - Coder coder = (Coder) fromByteArray(serializedCoder); - transform = transform.withCoder(coder); - } - return transform; - } - - @AutoService(TransformPayloadTranslatorRegistrar.class) - public static class Registrar implements TransformPayloadTranslatorRegistrar { - @Override - public Map, ? extends TransformPayloadTranslator> - getTransformPayloadTranslators() { - return Collections.singletonMap(Create.Values.class, new CreateTranslation()); - } - } -} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java index cedd8875751f..534a2b5fe0e6 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java @@ -295,7 +295,7 @@ public OutputT expand(InputT input) { response .getComponents() .toBuilder() - .putAllEnvironments(resolveArtifacts(newEnvironmentsWithDependencies)) + .putAllEnvironments(resolveArtifacts(newEnvironmentsWithDependencies, endpoint)) .build(); expandedTransform = response.getTransform(); expandedRequirements = response.getRequirementsList(); @@ -338,8 +338,8 @@ public OutputT expand(InputT input) { return toOutputCollection(outputMapBuilder.build()); } - private Map resolveArtifacts( - Map environments) { + static Map resolveArtifacts( + Map environments, Endpoints.ApiServiceDescriptor endpoint) { if (environments.size() == 0) { return environments; } @@ -367,7 +367,7 @@ private Map resolveArtifacts( } } - private RunnerApi.Environment resolveArtifacts( + private static RunnerApi.Environment resolveArtifacts( ArtifactRetrievalServiceGrpc.ArtifactRetrievalServiceBlockingStub retrievalStub, RunnerApi.Environment environment) throws IOException { @@ -378,7 +378,7 @@ private RunnerApi.Environment resolveArtifacts( .build(); } - private List resolveArtifacts( + private static List resolveArtifacts( ArtifactRetrievalServiceGrpc.ArtifactRetrievalServiceBlockingStub retrievalStub, List artifacts) throws IOException { diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java index 02d0ceb37e3b..8ecb5be91225 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java @@ -21,6 +21,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import java.io.IOException; +import java.io.ObjectOutputStream; import java.util.Collection; import java.util.Collections; import java.util.Comparator; @@ -37,10 +38,13 @@ import org.apache.beam.runners.core.construction.ExternalTranslation.ExternalTranslator; import org.apache.beam.runners.core.construction.ParDoTranslation.ParDoTranslator; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.schemas.SchemaTranslation; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.util.common.ReflectHelpers.ObjectsClassComparator; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; @@ -55,6 +59,8 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Utilities for converting {@link PTransform PTransforms} to {@link RunnerApi Runner API protocol @@ -66,6 +72,9 @@ "keyfor" }) // TODO(https://github.com/apache/beam/issues/20497) public class PTransformTranslation { + + private static final Logger LOG = LoggerFactory.getLogger(PTransformTranslation.class); + // We specifically copy the values here so that they can be used in switch case statements // and we validate that the value matches the actual URN in the static block below. @@ -85,6 +94,10 @@ public class PTransformTranslation { public static final ImmutableSet RUNNER_IMPLEMENTED_TRANSFORMS = ImmutableSet.of(GROUP_BY_KEY_TRANSFORM_URN, IMPULSE_TRANSFORM_URN); + public static final String CONFIG_ROW_KEY = "config_row"; + + public static final String CONFIG_ROW_SCHEMA_KEY = "config_row_schema"; + // DeprecatedPrimitives /** * @deprecated SDKs should move away from creating `Read` transforms and migrate to using Impulse @@ -388,9 +401,9 @@ public RunnerApi.PTransform translate( * Translates a set of registered transforms whose content only differs based by differences in * their {@link FunctionSpec}s and URNs. */ - public static class KnownTransformPayloadTranslator> + private static class KnownTransformPayloadTranslator> implements TransformTranslator { - public static final Map, TransformPayloadTranslator> + private static final Map, TransformPayloadTranslator> KNOWN_PAYLOAD_TRANSLATORS = loadTransformPayloadTranslators(); private static Map, TransformPayloadTranslator> @@ -437,10 +450,9 @@ public RunnerApi.PTransform translate( RunnerApi.PTransform.Builder transformBuilder = translateAppliedPTransform(appliedPTransform, subtransforms, components); - FunctionSpec spec = - KNOWN_PAYLOAD_TRANSLATORS - .get(appliedPTransform.getTransform().getClass()) - .translate(appliedPTransform, components); + TransformPayloadTranslator payloadTranslator = + KNOWN_PAYLOAD_TRANSLATORS.get(appliedPTransform.getTransform().getClass()); + FunctionSpec spec = payloadTranslator.translate(appliedPTransform, components); if (spec != null) { transformBuilder.setSpec(spec); @@ -463,6 +475,38 @@ public RunnerApi.PTransform translate( } } } + + Row configRow = null; + try { + configRow = payloadTranslator.toConfigRow(appliedPTransform.getTransform()); + } catch (UnsupportedOperationException e) { + // Optional toConfigRow() has not been implemented. We can just ignore. + } catch (Exception e) { + LOG.warn( + "Could not attach the config row for transform " + + appliedPTransform.getTransform().getName() + + ": " + + e); + // Ignoring the error and continuing with the translation since attaching config rows is + // optional. + } + if (configRow != null) { + ByteStringOutputStream rowOutputStream = new ByteStringOutputStream(); + try { + RowCoder.of(configRow.getSchema()).encode(configRow, rowOutputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + transformBuilder.putAnnotations(CONFIG_ROW_KEY, rowOutputStream.toByteString()); + + ByteStringOutputStream schemaOutputStream = new ByteStringOutputStream(); + try (ObjectOutputStream schemaObjOut = new ObjectOutputStream(schemaOutputStream)) { + schemaObjOut.writeObject(SchemaTranslation.schemaToProto(configRow.getSchema(), true)); + schemaObjOut.flush(); + transformBuilder.putAnnotations(CONFIG_ROW_SCHEMA_KEY, schemaOutputStream.toByteString()); + } + } + return transformBuilder.build(); } } @@ -531,9 +575,16 @@ default String getUrn(T transform) { return getUrn(); } + /** */ /** * Translates the given transform represented by the provided {@code AppliedPTransform} to a * {@code FunctionSpec} with a URN and a payload. + * + * @param application an {@code AppliedPTransform} that includes the transform to be expanded. + * @param components components of the pipeline that includes the transform. + * @return a generated spec for the transform to be included in the pipeline proto. If return + * value is null, transform should include an empty spec. + * @throws IOException */ @Nullable FunctionSpec translate(AppliedPTransform application, SdkComponents components) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java index 4d4c625154c8..e39a38a74c2c 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java @@ -18,39 +18,20 @@ package org.apache.beam.runners.core.construction; import java.io.IOException; -import java.net.ServerSocket; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Set; -import java.util.UUID; -import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; -import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.model.pipeline.v1.RunnerApi; -import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; -import org.apache.beam.runners.core.construction.External.ExpandableTransform; -import org.apache.beam.runners.core.construction.PTransformTranslation.KnownTransformPayloadTranslator; -import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; import org.apache.beam.runners.core.construction.graph.PipelineValidator; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; -import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.TransformHierarchy.Node; -import org.apache.beam.sdk.schemas.SchemaTranslation; -import org.apache.beam.sdk.transformservice.launcher.TransformServiceLauncher; -import org.apache.beam.sdk.util.ByteStringOutputStream; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.PInput; -import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ListMultimap; @@ -126,43 +107,13 @@ public void visitPrimitiveTransform(Node node) { pipeline.getOptions().as(ExternalTranslationOptions.class); List urnsToOverride = externalTranslationOptions.getTransformsToOverride(); if (urnsToOverride.size() > 0) { - // We use PTransformPayloadTranslators and the Transform Service to re-generate the pipeline - // proto components for the updated transform and update the pipeline proto. - Map transforms = res.getComponents().getTransformsMap(); - List alreadyCheckedURns = new ArrayList<>(); - for (Entry entry : transforms.entrySet()) { - String urn = entry.getValue().getSpec().getUrn(); - if (!alreadyCheckedURns.contains(urn) && urnsToOverride.contains(urn)) { - alreadyCheckedURns.add(urn); - // All transforms in the pipeline with the given urns have to be overridden. - List< - AppliedPTransform< - PInput, - POutput, - org.apache.beam.sdk.transforms.PTransform>> - appliedPTransforms = - findAppliedPTransforms( - urn, pipeline, KnownTransformPayloadTranslator.KNOWN_PAYLOAD_TRANSLATORS); - for (AppliedPTransform< - PInput, - POutput, - org.apache.beam.sdk.transforms.PTransform> - appliedPTransform : appliedPTransforms) { - TransformPayloadTranslator< - org.apache.beam.sdk.transforms.PTransform> - payloadTranslator = - KnownTransformPayloadTranslator.KNOWN_PAYLOAD_TRANSLATORS.get( - appliedPTransform.getTransform().getClass()); - try { - // Override the transform using the transform service. - res = - updateTransformViaTransformService( - urn, appliedPTransform, payloadTranslator, pipeline, res); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - } + try (TransformUpgrader upgrader = TransformUpgrader.of()) { + res = + upgrader.upgradeTransformsViaTransformService( + res, urnsToOverride, externalTranslationOptions); + } catch (Exception e) { + throw new RuntimeException( + "Could not override the transforms with URNs " + urnsToOverride, e); } } @@ -171,320 +122,6 @@ public void visitPrimitiveTransform(Node node) { return res; } - private static int findAvailablePort() throws IOException { - ServerSocket s = new ServerSocket(0); - try { - return s.getLocalPort(); - } finally { - s.close(); - try { - // Some systems don't free the port for future use immediately. - Thread.sleep(100); - } catch (InterruptedException exn) { - // ignore - } - } - } - - // Override the given transform to the version available in a new transform service. - private static < - InputT extends PInput, - OutputT extends POutput, - TransformT extends org.apache.beam.sdk.transforms.PTransform> - RunnerApi.Pipeline updateTransformViaTransformService( - String urn, - AppliedPTransform< - PInput, - POutput, - org.apache.beam.sdk.transforms.PTransform> - appliedPTransform, - TransformPayloadTranslator< - org.apache.beam.sdk.transforms.PTransform> - originalPayloadTranslator, - Pipeline pipeline, - RunnerApi.Pipeline runnerAPIpipeline) - throws IOException { - ExternalTranslationOptions externalTranslationOptions = - pipeline.getOptions().as(ExternalTranslationOptions.class); - - // Config row to re-construct the transform within the transform service. - Row configRow = originalPayloadTranslator.toConfigRow(appliedPTransform.getTransform()); - ByteStringOutputStream outputStream = new ByteStringOutputStream(); - try { - RowCoder.of(configRow.getSchema()).encode(configRow, outputStream); - } catch (IOException e) { - throw new RuntimeException(e); - } - - // Java expansion serivice able to identify and expand transforms that includes the construction - // config provided here. - ExternalTransforms.ExternalConfigurationPayload payload = - ExternalTransforms.ExternalConfigurationPayload.newBuilder() - .setSchema(SchemaTranslation.schemaToProto(configRow.getSchema(), true)) - .setPayload(outputStream.toByteString()) - .build(); - - String serviceAddress = null; - TransformServiceLauncher service = null; - try { - if (externalTranslationOptions.getTransformServiceAddress() != null) { - serviceAddress = externalTranslationOptions.getTransformServiceAddress(); - } else if (externalTranslationOptions.getTransformServiceBeamVersion() != null) { - String projectName = UUID.randomUUID().toString(); - service = TransformServiceLauncher.forProject(projectName, findAvailablePort()); - service.setBeamVersion(externalTranslationOptions.getTransformServiceBeamVersion()); - - // Starting the transform service. - service.start(); - // Waiting the service to be ready. - service.waitTillUp(15000); - } else { - throw new IllegalArgumentException( - "Either option TransformServiceAddress or option TransformServiceBeamVersion should be provided to override a transform using the transform service"); - } - - if (serviceAddress == null) { - throw new IllegalArgumentException( - "Cannot override the transform " - + urn - + " since a valid transform service address could not be determined"); - } - - // Creating an ExternalTransform and expanding it using the transform service. - // Input will be the same input provided to the transform bing overridden. - ExpandableTransform externalTransform = - (ExpandableTransform) - External.of(urn, payload.toByteArray(), serviceAddress); - - PCollectionTuple input = PCollectionTuple.empty(pipeline); - for (TupleTag tag : (Set>) appliedPTransform.getInputs().keySet()) { - PCollection pc = appliedPTransform.getInputs().get(tag); - if (pc == null) { - throw new IllegalArgumentException( - "Input of transform " + appliedPTransform + " with tag " + tag + " was null."); - } - input = input.and(tag, (PCollection) pc); - } - POutput output = externalTransform.expand((InputT) input); - - // Outputs of the transform being overridden. - Map, PCollection> originalOutputs = appliedPTransform.getOutputs(); - - // After expansion some transforms might still refer to the output of the already overridden - // transform as their input. - // Such inputs have to be overridden to use the output of the new upgraded transform. - Map inputReplacements = new HashMap<>(); - - // Will contain the outputs of the upgraded transform. - Map, PCollection> newOutputs = new HashMap<>(); - - if (output instanceof PCollectionTuple) { - newOutputs.putAll(((PCollectionTuple) output).getAll()); - for (Map.Entry, PCollection> entry : newOutputs.entrySet()) { - if (entry == null) { - throw new IllegalArgumentException( - "Found unexpected null entry when iterating the outputs of expanded " - + "ExpandableTransform " - + externalTransform); - } - if (!appliedPTransform.getOutputs().containsKey(entry.getKey())) { - throw new RuntimeException( - "Could not find the tag " + entry.getKey() + " in the original set of outputs"); - } - PCollection originalOutputPc = originalOutputs.get(entry.getKey()); - if (originalOutputPc == null) { - throw new IllegalArgumentException( - "Original output of transform " - + appliedPTransform - + " with tag " - + entry.getKey() - + " was null"); - } - inputReplacements.put(originalOutputPc.getName(), entry.getValue().getName()); - } - } else if (output instanceof PCollection) { - newOutputs.put(new TupleTag<>("temp_main_tag"), (PCollection) output); - inputReplacements.put( - originalOutputs.get(originalOutputs.keySet().iterator().next()).getName(), - ((PCollection) output).getName()); - } else { - throw new RuntimeException("Unexpected output type"); - } - - // We create a new AppliedPTransform to represent the upgraded transform and register it in an - // SdkComponents object. - AppliedPTransform updatedAppliedPTransform = - AppliedPTransform.of( - appliedPTransform.getFullName() + "_external", - appliedPTransform.getInputs(), - newOutputs, - externalTransform, - externalTransform.getResourceHints(), - appliedPTransform.getPipeline()); - SdkComponents updatedComponents = - SdkComponents.create( - runnerAPIpipeline.getComponents(), runnerAPIpipeline.getRequirementsList()); - String updatedTransformId = - updatedComponents.registerPTransform(updatedAppliedPTransform, Collections.emptyList()); - RunnerApi.Components updatedRunnerApiComponents = updatedComponents.toComponents(); - - // Recording input updates to the transforms to refer to the upgraded transform instead of the - // old one. - // Also recording the newly generated id of the old (overridden) transform in the - // updatedRunnerApiComponents. - Map> transformInputUpdates = new HashMap<>(); - List oldTransformIds = new ArrayList<>(); - updatedRunnerApiComponents - .getTransformsMap() - .forEach( - (transformId, transform) -> { - // Mapping from existing key to new value. - Map updatedInputMap = new HashMap<>(); - for (Map.Entry entry : transform.getInputsMap().entrySet()) { - if (inputReplacements.containsKey(entry.getValue())) { - updatedInputMap.put(entry.getKey(), inputReplacements.get(entry.getValue())); - } - } - for (Map.Entry entry : transform.getOutputsMap().entrySet()) { - if (inputReplacements.containsKey(entry.getValue()) - && urn.equals(transform.getSpec().getUrn())) { - oldTransformIds.add(transformId); - } - } - if (updatedInputMap.size() > 0) { - transformInputUpdates.put(transformId, updatedInputMap); - } - }); - // There should be only one recorded old (upgraded) transform. - if (oldTransformIds.size() != 1) { - throw new IOException( - "Expected exactly one transform to be updated by " - + oldTransformIds.size() - + " were updated."); - } - String oldTransformId = oldTransformIds.get(0); - - // Updated list of root transforms (in case a root was upgraded). - List updaterRootTransformIds = new ArrayList<>(); - updaterRootTransformIds.addAll(runnerAPIpipeline.getRootTransformIdsList()); - if (updaterRootTransformIds.contains(oldTransformId)) { - updaterRootTransformIds.remove(oldTransformId); - updaterRootTransformIds.add(updatedTransformId); - } - - // Generating the updated list of transforms. - // Also updates the input references to refer to the upgraded transform. - // Also updates the sub-transform reference to refer to the new transform. - Map updatedTransforms = new HashMap<>(); - updatedRunnerApiComponents - .getTransformsMap() - .forEach( - (transformId, transform) -> { - if (transformId.equals(oldTransformId)) { - // Do not include the old (upgraded) transform. - return; - } - PTransform.Builder transformBuilder = transform.toBuilder(); - if (transformInputUpdates.containsKey(transformId)) { - Map inputUpdates = transformInputUpdates.get(transformId); - transformBuilder - .getInputsMap() - .forEach( - (key, value) -> { - if (inputUpdates.containsKey(key)) { - transformBuilder.putInputs(key, inputUpdates.get(key)); - } - }); - } - if (transform.getSubtransformsList().contains(oldTransformId)) { - List updatedSubTransformsList = new ArrayList<>(); - updatedSubTransformsList.addAll(transform.getSubtransformsList()); - updatedSubTransformsList.remove(oldTransformId); - updatedSubTransformsList.add(updatedTransformId); - transformBuilder.clearSubtransforms(); - transformBuilder.addAllSubtransforms(updatedSubTransformsList); - } - updatedTransforms.put(transformId, transformBuilder.build()); - }); - - // Generating components with the updated list of transforms without including the old - // (upgraded) transform. - updatedRunnerApiComponents = - updatedRunnerApiComponents - .toBuilder() - .putAllTransforms(updatedTransforms) - .removeTransforms(oldTransformId) - .build(); - - // Generating the updated pipeline. - RunnerApi.Pipeline updatedPipeline = - RunnerApi.Pipeline.newBuilder() - .setComponents(updatedRunnerApiComponents) - .addAllRequirements(updatedComponents.requirements()) - .addAllRootTransformIds(updaterRootTransformIds) - .build(); - - return updatedPipeline; - } catch (TimeoutException e) { - throw new IOException(e); - } finally { - if (service != null) { - service.shutdown(); - } - } - } - - // Find all AppliedPTransforms that represent transforms with the given URN. - @SuppressWarnings({ - // Pre-registered 'knownTranslators' are defined as raw types. - "rawtypes" - }) - private static < - InputT extends PInput, - OutputT extends POutput, - TransformT extends org.apache.beam.sdk.transforms.PTransform> - List> findAppliedPTransforms( - String urn, - Pipeline pipeline, - Map< - Class, - TransformPayloadTranslator> - knownTranslators) { - - List> appliedPTransforms = new ArrayList<>(); - pipeline.traverseTopologically( - new PipelineVisitor.Defaults() { - - void findMatchingAppliedPTransform(Node node) { - org.apache.beam.sdk.transforms.PTransform transform = node.getTransform(); - if (transform == null) { - return; - } - if (knownTranslators.containsKey(transform.getClass())) { - TransformPayloadTranslator translator = - knownTranslators.get(transform.getClass()); - if (translator.getUrn() != null && translator.getUrn().equals(urn)) { - appliedPTransforms.add( - (AppliedPTransform) - node.toAppliedPTransform(pipeline)); - } - } - } - - @Override - public void leaveCompositeTransform(Node node) { - findMatchingAppliedPTransform(node); - } - - @Override - public void visitPrimitiveTransform(Node node) { - findMatchingAppliedPTransform(node); - } - }); - - return appliedPTransforms; - } - private static RunnerApi.Pipeline elideDeprecatedViews(RunnerApi.Pipeline pipeline) { // Record data on CreateView operations. Set viewTransforms = new HashSet<>(); diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java new file mode 100644 index 000000000000..00ea9da71e78 --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.construction; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.net.ServerSocket; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.Endpoints; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; +import org.apache.beam.model.pipeline.v1.SchemaApi; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transformservice.launcher.TransformServiceLauncher; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannelBuilder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; + +/** + * A utility class that allows upgrading transforms of a given pipeline using the Beam Transform + * Service. + */ +public class TransformUpgrader implements AutoCloseable { + private static final String UPGRADE_NAMESPACE = "transform:upgrade:"; + + private ExpansionServiceClientFactory clientFactory; + + private static final ExpansionServiceClientFactory DEFAULT = + DefaultExpansionServiceClientFactory.create( + endPoint -> ManagedChannelBuilder.forTarget(endPoint.getUrl()).usePlaintext().build()); + + private TransformUpgrader(ExpansionServiceClientFactory clientFactory) { + this.clientFactory = clientFactory; + } + + public static TransformUpgrader of() { + return new TransformUpgrader(DEFAULT); + } + + public static TransformUpgrader of(ExpansionServiceClientFactory clientFactory) { + return new TransformUpgrader(clientFactory); + } + + /** + * Upgrade identified transforms in a given pipeline using the Transform Service. + * + * @param pipeline the pipeline proto. + * @param urnsToOverride URNs of the transforms to be overridden. + * @param options options for determining the transform service to use. + * @return pipelines with transforms upgraded using the Transform Service. + * @throws Exception + */ + public RunnerApi.Pipeline upgradeTransformsViaTransformService( + RunnerApi.Pipeline pipeline, List urnsToOverride, ExternalTranslationOptions options) + throws Exception { + List transformsToOverride = + pipeline.getComponents().getTransformsMap().entrySet().stream() + .filter( + entry -> { + String urn = entry.getValue().getSpec().getUrn(); + if (urn != null && urnsToOverride.contains(urn)) { + return true; + } + return false; + }) + .map( + entry -> { + return entry.getKey(); + }) + .collect(Collectors.toList()); + + String serviceAddress; + TransformServiceLauncher service = null; + try { + if (options.getTransformServiceAddress() != null) { + serviceAddress = options.getTransformServiceAddress(); + } else if (options.getTransformServiceBeamVersion() != null) { + String projectName = UUID.randomUUID().toString(); + int port = findAvailablePort(); + service = TransformServiceLauncher.forProject(projectName, port); + service.setBeamVersion(options.getTransformServiceBeamVersion()); + + // Starting the transform service. + service.start(); + service.waitTillUp(40); + serviceAddress = "localhost:" + Integer.toString(port); + System.out.println("Done waiting ..."); + } else { + throw new IllegalArgumentException( + "Either option TransformServiceAddress or option TransformServiceBeamVersion should be " + + "provided to override a transform using the transform service"); + } + + Endpoints.ApiServiceDescriptor expansionServiceEndpoint = + Endpoints.ApiServiceDescriptor.newBuilder().setUrl(serviceAddress).build(); + + for (String transformId : transformsToOverride) { + pipeline = + updateTransformViaTransformService(pipeline, transformId, expansionServiceEndpoint); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (service != null) { + try { + service.shutdown(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + return pipeline; + } + + private < + InputT extends PInput, + OutputT extends POutput, + TransformT extends org.apache.beam.sdk.transforms.PTransform> + RunnerApi.Pipeline updateTransformViaTransformService( + RunnerApi.Pipeline runnerAPIpipeline, + String transformId, + Endpoints.ApiServiceDescriptor transformServiceEndpoint) + throws Exception { + PTransform transformToUpgrade = + runnerAPIpipeline.getComponents().getTransformsMap().get(transformId); + if (transformToUpgrade == null) { + throw new Exception("Could not find a transform with the ID " + transformId); + } + + ByteString configRowBytes = + transformToUpgrade.getAnnotationsOrThrow(PTransformTranslation.CONFIG_ROW_KEY); + ByteString configRowSchemaBytes = + transformToUpgrade.getAnnotationsOrThrow(PTransformTranslation.CONFIG_ROW_SCHEMA_KEY); + SchemaApi.Schema configRowSchemaProto; + try { + configRowSchemaProto = + (SchemaApi.Schema) + new ObjectInputStream(new ByteArrayInputStream(configRowSchemaBytes.toByteArray())) + .readObject(); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + + Row configRow = + RowCoder.of(SchemaTranslation.schemaFromProto(configRowSchemaProto)) + .decode(new ByteArrayInputStream(configRowBytes.toByteArray())); + ByteStringOutputStream outputStream = new ByteStringOutputStream(); + try { + RowCoder.of(configRow.getSchema()).encode(configRow, outputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + + ExternalTransforms.ExternalConfigurationPayload payload = + ExternalTransforms.ExternalConfigurationPayload.newBuilder() + .setSchema(configRowSchemaProto) + .setPayload(outputStream.toByteString()) + .build(); + + RunnerApi.PTransform.Builder ptransformBuilder = + RunnerApi.PTransform.newBuilder() + .setUniqueName(transformToUpgrade.getUniqueName() + "_external") + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(transformToUpgrade.getSpec().getUrn()) + .setPayload(ByteString.copyFrom(payload.toByteArray())) + .build()); + + for (Map.Entry entry : transformToUpgrade.getInputsMap().entrySet()) { + ptransformBuilder.putInputs(entry.getKey(), entry.getValue()); + } + for (Map.Entry entry : transformToUpgrade.getOutputsMap().entrySet()) { + ptransformBuilder.putOutputs(entry.getKey(), entry.getValue()); + } + + ExpansionApi.ExpansionRequest.Builder requestBuilder = + ExpansionApi.ExpansionRequest.newBuilder(); + ExpansionApi.ExpansionRequest request = + requestBuilder + .setComponents(runnerAPIpipeline.getComponents()) + .setTransform(ptransformBuilder.build()) + .setNamespace(UPGRADE_NAMESPACE) + .build(); + + ExpansionApi.ExpansionResponse response = + clientFactory.getExpansionServiceClient(transformServiceEndpoint).expand(request); + + if (!Strings.isNullOrEmpty(response.getError())) { + throw new IOException(String.format("expansion service error: %s", response.getError())); + } + + Map newEnvironmentsWithDependencies = + response.getComponents().getEnvironmentsMap().entrySet().stream() + .filter( + kv -> + !runnerAPIpipeline.getComponents().getEnvironmentsMap().containsKey(kv.getKey()) + && kv.getValue().getDependenciesCount() != 0) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + RunnerApi.Components expandedComponents = + response + .getComponents() + .toBuilder() + .putAllEnvironments( + External.ExpandableTransform.resolveArtifacts( + newEnvironmentsWithDependencies, transformServiceEndpoint)) + .build(); + RunnerApi.PTransform expandedTransform = response.getTransform(); + List expandedRequirements = response.getRequirementsList(); + + RunnerApi.Components.Builder newComponentsBuilder = expandedComponents.toBuilder(); + + // Some transforms may refer to already overridden transform as one of their input. We record + // such occurrences and correct them by referring to the upgraded transform instead. + Collection oldOutputs = transformToUpgrade.getOutputsMap().values(); + Map inputReplacements = new HashMap<>(); + if (transformToUpgrade.getOutputsMap().size() == 1) { + inputReplacements.put( + oldOutputs.iterator().next(), + expandedTransform.getOutputsMap().values().iterator().next()); + } else { + for (Map.Entry entry : transformToUpgrade.getOutputsMap().entrySet()) { + if (expandedTransform.getOutputsMap().keySet().contains(entry.getKey())) { + throw new Exception( + "Original transform did not have an output with tag " + + entry.getKey() + + " but upgraded transform did."); + } + String newOutput = expandedTransform.getOutputsMap().get(entry.getKey()); + if (newOutput == null) { + throw new Exception( + "Could not find an output with tag " + + entry.getKey() + + " for the transform " + + expandedTransform); + } + inputReplacements.put(entry.getValue(), newOutput); + } + } + + String newTransformId = transformId + "_upgraded"; + + // The list of obsolete (overridden) transforms that should be removed from the pipeline + // produced by this method. + List transformsToRemove = new ArrayList<>(); + recursivelyFindSubTransforms( + transformId, runnerAPIpipeline.getComponents(), transformsToRemove); + + Map updatedExpandedTransformMap = + expandedComponents.getTransformsMap().entrySet().stream() + .filter( + entry -> { + // Do not include already overridden transforms. + return !transformsToRemove.contains(entry.getKey()); + }) + .collect( + Collectors.toMap( + entry -> entry.getKey(), + entry -> { + // Fix inputs + Map inputsMap = entry.getValue().getInputsMap(); + PTransform.Builder transformBuilder = entry.getValue().toBuilder(); + if (!Collections.disjoint(inputsMap.values(), inputReplacements.keySet())) { + Map updatedInputsMap = new HashMap<>(); + for (Map.Entry inputEntry : inputsMap.entrySet()) { + String updaterValue = + inputReplacements.containsKey(inputEntry.getValue()) + ? inputReplacements.get(inputEntry.getValue()) + : inputEntry.getValue(); + updatedInputsMap.put(inputEntry.getKey(), updaterValue); + } + transformBuilder.clearInputs(); + transformBuilder.putAllInputs(updatedInputsMap); + } + + // Fix sub-transforms + if (entry.getValue().getSubtransformsList().contains(transformId)) { + List updatedSubTransforms = + entry.getValue().getSubtransformsList().stream() + .map( + subtransformId -> { + return subtransformId.equals(transformId) + ? newTransformId + : subtransformId; + }) + .collect(Collectors.toList()); + transformBuilder.clearSubtransforms(); + transformBuilder.addAllSubtransforms(updatedSubTransforms); + } + + return transformBuilder.build(); + })); + + newComponentsBuilder.clearTransforms(); + newComponentsBuilder.putAllTransforms(updatedExpandedTransformMap); + newComponentsBuilder.putTransforms(newTransformId, expandedTransform); + + // We fix the root in case the overridden transform was one of the roots. + List rootTransformIds = + runnerAPIpipeline.getRootTransformIdsList().stream() + .map(id -> id.equals(transformId) ? newTransformId : id) + .collect(Collectors.toList()); + + RunnerApi.Pipeline.Builder newRunnerAPIPipelineBuilder = runnerAPIpipeline.toBuilder(); + newRunnerAPIPipelineBuilder.clearComponents(); + newRunnerAPIPipelineBuilder.setComponents(newComponentsBuilder.build()); + + newRunnerAPIPipelineBuilder.addAllRequirements(expandedRequirements); + newRunnerAPIPipelineBuilder.clearRootTransformIds(); + newRunnerAPIPipelineBuilder.addAllRootTransformIds(rootTransformIds); + + return newRunnerAPIPipelineBuilder.build(); + } + + private static void recursivelyFindSubTransforms( + String transformId, RunnerApi.Components components, List results) { + results.add(transformId); + PTransform transform = components.getTransformsMap().get(transformId); + if (transform == null) { + throw new IllegalArgumentException("Could not find a transform with id " + transformId); + } + List subTransforms = transform.getSubtransformsList(); + if (subTransforms != null) { + for (String subTransformId : subTransforms) { + recursivelyFindSubTransforms(subTransformId, components, results); + } + } + } + + private static int findAvailablePort() throws IOException { + ServerSocket s = new ServerSocket(0); + try { + return s.getLocalPort(); + } finally { + s.close(); + try { + // Some systems don't free the port for future use immediately. + Thread.sleep(100); + } catch (InterruptedException exn) { + // ignore + } + } + } + + @Override + public void close() throws Exception { + clientFactory.close(); + } +} diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformUpgraderTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformUpgraderTest.java new file mode 100644 index 000000000000..abf6ccd81585 --- /dev/null +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformUpgraderTest.java @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.construction; + +import static org.junit.Assert.assertEquals; + +import com.google.auto.service.AutoService; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.Endpoints; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.transforms.ToString; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for TransformServiceBasedOverride. */ +@RunWith(JUnit4.class) +public class TransformUpgraderTest { + static class TestTransform extends PTransform, PCollection> { + private int testParam; + + public TestTransform(int testParam) { + this.testParam = testParam; + } + + @Override + public PCollection expand(PCollection input) { + return input.apply( + MapElements.via( + new SimpleFunction() { + @Override + public Integer apply(Integer input) { + return input * testParam; + } + })); + } + + public Integer getTestParam() { + return testParam; + } + } + + static class TestTransformPayloadTranslator + implements PTransformTranslation.TransformPayloadTranslator { + + static final String URN = "beam:transform:test:transform_to_update"; + + Schema configRowSchema = Schema.builder().addInt32Field("multiplier").build(); + + @Override + public String getUrn() { + return URN; + } + + @Override + public TestTransform fromConfigRow(Row configRow) { + return new TestTransform(configRow.getInt32("multiplier")); + } + + @Override + public Row toConfigRow(TestTransform transform) { + return Row.withSchema(configRowSchema) + .withFieldValue("multiplier", transform.getTestParam()) + .build(); + } + + @Override + public RunnerApi.@Nullable FunctionSpec translate( + AppliedPTransform application, SdkComponents components) + throws IOException { + + int testParam = application.getTransform().getTestParam(); + + FunctionSpec.Builder specBuilder = FunctionSpec.newBuilder(); + specBuilder.setUrn(getUrn()); + + ByteStringOutputStream byteStringOut = new ByteStringOutputStream(); + ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteStringOut); + objectOutputStream.writeObject(testParam); + objectOutputStream.flush(); + specBuilder.setPayload(byteStringOut.toByteString()); + + return specBuilder.build(); + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class Registrar implements TransformPayloadTranslatorRegistrar { + @Override + public Map, TestTransformPayloadTranslator> + getTransformPayloadTranslators() { + return Collections.singletonMap(TestTransform.class, new TestTransformPayloadTranslator()); + } + } + + static class TestTransform2 extends TestTransform { + public TestTransform2(int testParam) { + super(testParam); + } + } + + static class TestTransformPayloadTranslator2 extends TestTransformPayloadTranslator { + static final String URN = "beam:transform:test:transform_to_update2"; + + @Override + public String getUrn() { + return URN; + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class Registrar2 implements TransformPayloadTranslatorRegistrar { + @Override + public Map, TestTransformPayloadTranslator2> + getTransformPayloadTranslators() { + return Collections.singletonMap(TestTransform2.class, new TestTransformPayloadTranslator2()); + } + } + + static class TestExpansionServiceClientFactory implements ExpansionServiceClientFactory { + ExpansionApi.ExpansionResponse response; + + @Override + public ExpansionServiceClient getExpansionServiceClient( + Endpoints.ApiServiceDescriptor endpoint) { + return new ExpansionServiceClient() { + @Override + public ExpansionApi.ExpansionResponse expand(ExpansionApi.ExpansionRequest request) { + RunnerApi.Components.Builder responseComponents = request.getComponents().toBuilder(); + RunnerApi.PTransform transformToUpgrade = + request.getComponents().getTransformsMap().get("TransformUpgraderTest-TestTransform"); + if (transformToUpgrade == null) { + transformToUpgrade = + request + .getComponents() + .getTransformsMap() + .get("TransformUpgraderTest-TestTransform2"); + } + + Integer oldParam; + try { + ByteArrayInputStream byteArrayInputStream = + new ByteArrayInputStream(transformToUpgrade.getSpec().getPayload().toByteArray()); + ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream); + oldParam = (Integer) objectInputStream.readObject(); + } catch (Exception e) { + throw new RuntimeException(e); + } + + System.out.println("Oldparam: " + oldParam); + RunnerApi.PTransform.Builder upgradedTransform = transformToUpgrade.toBuilder(); + FunctionSpec.Builder specBuilder = upgradedTransform.getSpecBuilder(); + + ByteStringOutputStream byteStringOutputStream = new ByteStringOutputStream(); + try { + ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteStringOutputStream); + objectOutputStream.writeObject(oldParam * 2); + objectOutputStream.flush(); + specBuilder.setPayload(byteStringOutputStream.toByteString()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + upgradedTransform.setSpec(specBuilder.build()); + + response = + ExpansionApi.ExpansionResponse.newBuilder() + .setComponents(responseComponents.build()) + .setTransform(upgradedTransform.build()) + .build(); + return response; + } + + @Override + public ExpansionApi.DiscoverSchemaTransformResponse discover( + ExpansionApi.DiscoverSchemaTransformRequest request) { + return null; + } + + @Override + public void close() throws Exception { + // do nothing + } + }; + } + + @Override + public void close() throws Exception { + // do nothing + } + } + + private void validateTestParam(RunnerApi.PTransform updatedTestTransform, Integer expectedValue) { + Integer updatedParam; + try { + ByteArrayInputStream byteArrayInputStream = + new ByteArrayInputStream(updatedTestTransform.getSpec().getPayload().toByteArray()); + ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream); + updatedParam = (Integer) objectInputStream.readObject(); + } catch (Exception e) { + throw new RuntimeException(e); + } + + assertEquals(Integer.valueOf(expectedValue), updatedParam); + + System.out.println("Updated param: " + updatedParam); + } + + @Test + public void testTransformUpgrade() throws Exception { + Pipeline pipeline = Pipeline.create(); + pipeline + .apply(Create.of(1, 2, 3)) + .apply(new TestTransform(2)) + .apply(ToString.elements()) + .apply(TextIO.write().to("dummyfilename")); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline, false); + ExternalTranslationOptions options = + PipelineOptionsFactory.create().as(ExternalTranslationOptions.class); + List urnsToOverride = ImmutableList.of(TestTransformPayloadTranslator.URN); + options.setTransformsToOverride(urnsToOverride); + options.setTransformServiceAddress("dummyaddress"); + + RunnerApi.Pipeline upgradedPipelineProto = + TransformUpgrader.of(new TestExpansionServiceClientFactory()) + .upgradeTransformsViaTransformService(pipelineProto, urnsToOverride, options); + + RunnerApi.PTransform upgradedTransform = + upgradedPipelineProto + .getComponents() + .getTransformsMap() + .get("TransformUpgraderTest-TestTransform_upgraded"); + + validateTestParam(upgradedTransform, 4); + } + + @Test + public void testTransformUpgradeMultipleOccurrences() throws Exception { + Pipeline pipeline = Pipeline.create(); + pipeline + .apply(Create.of(1, 2, 3)) + .apply(new TestTransform(2)) + .apply(ToString.elements()) + .apply(TextIO.write().to("dummyfilename")); + pipeline + .apply(Create.of(1, 2, 3)) + .apply(new TestTransform(2)) + .apply(ToString.elements()) + .apply(TextIO.write().to("dummyfilename")); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline, false); + ExternalTranslationOptions options = + PipelineOptionsFactory.create().as(ExternalTranslationOptions.class); + List urnsToOverride = ImmutableList.of(TestTransformPayloadTranslator.URN); + options.setTransformsToOverride(urnsToOverride); + options.setTransformServiceAddress("dummyaddress"); + + RunnerApi.Pipeline upgradedPipelineProto = + TransformUpgrader.of(new TestExpansionServiceClientFactory()) + .upgradeTransformsViaTransformService(pipelineProto, urnsToOverride, options); + + RunnerApi.PTransform upgradedTransform1 = + upgradedPipelineProto + .getComponents() + .getTransformsMap() + .get("TransformUpgraderTest-TestTransform_upgraded"); + validateTestParam(upgradedTransform1, 4); + + RunnerApi.PTransform upgradedTransform2 = + upgradedPipelineProto + .getComponents() + .getTransformsMap() + .get("TransformUpgraderTest-TestTransform2_upgraded"); + validateTestParam(upgradedTransform2, 4); + } + + @Test + public void testTransformUpgradeMultipleURNs() throws Exception { + Pipeline pipeline = Pipeline.create(); + pipeline + .apply(Create.of(1, 2, 3)) + .apply(new TestTransform(2)) + .apply(ToString.elements()) + .apply(TextIO.write().to("dummyfilename")); + pipeline + .apply(Create.of(1, 2, 3)) + .apply(new TestTransform2(2)) + .apply(ToString.elements()) + .apply(TextIO.write().to("dummyfilename")); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline, false); + ExternalTranslationOptions options = + PipelineOptionsFactory.create().as(ExternalTranslationOptions.class); + List urnsToOverride = + ImmutableList.of(TestTransformPayloadTranslator.URN, TestTransformPayloadTranslator2.URN); + options.setTransformsToOverride(urnsToOverride); + options.setTransformServiceAddress("dummyaddress"); + + RunnerApi.Pipeline upgradedPipelineProto = + TransformUpgrader.of(new TestExpansionServiceClientFactory()) + .upgradeTransformsViaTransformService(pipelineProto, urnsToOverride, options); + + RunnerApi.PTransform upgradedTransform1 = + upgradedPipelineProto + .getComponents() + .getTransformsMap() + .get("TransformUpgraderTest-TestTransform_upgraded"); + validateTestParam(upgradedTransform1, 4); + + RunnerApi.PTransform upgradedTransform2 = + upgradedPipelineProto + .getComponents() + .getTransformsMap() + .get("TransformUpgraderTest-TestTransform2_upgraded"); + validateTestParam(upgradedTransform2, 4); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java index 2c96f7c6b1f0..079379953cd9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java @@ -333,10 +333,6 @@ public Iterable getElements() { return elems; } - public @Nullable Coder getCoder() { - return coder.isPresent() ? coder.get() : null; - } - @Override public PCollection expand(PBegin input) { Coder coder; diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java index a1d4826e0791..ec53e3f11e43 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java @@ -194,7 +194,7 @@ public List getDependencies( continue; } - String urn = null; + String urn; try { urn = translator.getUrn(); if (urn == null) { From 7a4271f9f4b270f89788fbe6248d11c9ba447138 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Wed, 20 Sep 2023 16:31:23 -0700 Subject: [PATCH 3/8] Reduce visibility of the test-only constructor --- .../beam/runners/core/construction/TransformUpgrader.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java index 00ea9da71e78..8abb606df8e7 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java @@ -44,6 +44,7 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannelBuilder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; /** @@ -67,7 +68,8 @@ public static TransformUpgrader of() { return new TransformUpgrader(DEFAULT); } - public static TransformUpgrader of(ExpansionServiceClientFactory clientFactory) { + @VisibleForTesting + static TransformUpgrader of(ExpansionServiceClientFactory clientFactory) { return new TransformUpgrader(clientFactory); } From e589f03f8eca047c6ae90e8061e39cbd1da9bd1e Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Thu, 21 Sep 2023 11:00:50 -0700 Subject: [PATCH 4/8] Fix compile errors --- .../runners/flink/FlinkStreamingTransformTranslators.java | 2 +- .../beam/runners/samza/translation/SamzaPublishView.java | 2 +- .../translation/streaming/StreamingTransformTranslator.java | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index b725bfbb8d40..6d42d0c3b485 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -1446,7 +1446,7 @@ private static class CreateStreamingFlinkViewPayloadTranslator private CreateStreamingFlinkViewPayloadTranslator() {} @Override - public String getUrn(CreateStreamingFlinkView.CreateFlinkPCollectionView transform) { + public String getUrn() { return CreateStreamingFlinkView.CREATE_STREAMING_FLINK_VIEW_URN; } } diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishView.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishView.java index 9a50d3d579ac..a3ebbffef9a8 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishView.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishView.java @@ -59,7 +59,7 @@ static class SamzaPublishViewPayloadTranslator SamzaPublishViewPayloadTranslator() {} @Override - public String getUrn(SamzaPublishView transform) { + public String getUrn() { return SAMZA_PUBLISH_VIEW_URN; } } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index eaa267375db3..266b67798a22 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -636,7 +636,7 @@ private static class SparkConsoleIOWriteUnboundedPayloadTranslator ConsoleIO.Write.Unbound> { @Override - public String getUrn(ConsoleIO.Write.Unbound transform) { + public String getUrn() { return ConsoleIO.Write.Unbound.TRANSFORM_URN; } } @@ -645,7 +645,7 @@ private static class SparkCreateStreamPayloadTranslator extends PTransformTranslation.TransformPayloadTranslator.NotSerializable> { @Override - public String getUrn(CreateStream transform) { + public String getUrn() { return CreateStream.TRANSFORM_URN; } } From 6a85535b6ed6a6dc890fd565f88147965681ef84 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Thu, 21 Sep 2023 13:48:02 -0700 Subject: [PATCH 5/8] Fix spotless --- .../java/org/apache/beam/runners/dataflow/DataflowRunner.java | 3 ++- .../org/apache/beam/runners/dataflow/DataflowRunnerTest.java | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 1686891594e4..26548038a1df 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -2574,7 +2574,8 @@ public String getUrn(PTransform transform) { @Override public String getUrn() { - throw new UnsupportedOperationException("URN of DataflowPayloadTranslator depends on the transform. Please use 'getUrn(PTransform transform)' instead."); + throw new UnsupportedOperationException( + "URN of DataflowPayloadTranslator depends on the transform. Please use 'getUrn(PTransform transform)' instead."); } @Override diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index 978e57313e4a..078f25e0e38e 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -167,7 +167,6 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PValues; -import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.sdk.values.WindowingStrategy; From 732bd32c168b190ed052029787ee52f5005d3d44 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Thu, 21 Sep 2023 17:52:12 -0700 Subject: [PATCH 6/8] Addressing reviewer comments --- .../construction/PTransformTranslation.java | 27 ++-- .../core/construction/TransformUpgrader.java | 130 +++++------------- .../construction/TransformUpgraderTest.java | 32 +++-- 3 files changed, 71 insertions(+), 118 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java index 8ecb5be91225..8f415e718e95 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java @@ -21,7 +21,6 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import java.io.IOException; -import java.io.ObjectOutputStream; import java.util.Collection; import java.util.Collections; import java.util.Comparator; @@ -44,13 +43,14 @@ import org.apache.beam.sdk.schemas.SchemaTranslation; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.common.ReflectHelpers.ObjectsClassComparator; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; @@ -491,20 +491,15 @@ public RunnerApi.PTransform translate( // optional. } if (configRow != null) { - ByteStringOutputStream rowOutputStream = new ByteStringOutputStream(); - try { - RowCoder.of(configRow.getSchema()).encode(configRow, rowOutputStream); - } catch (IOException e) { - throw new RuntimeException(e); - } - transformBuilder.putAnnotations(CONFIG_ROW_KEY, rowOutputStream.toByteString()); - - ByteStringOutputStream schemaOutputStream = new ByteStringOutputStream(); - try (ObjectOutputStream schemaObjOut = new ObjectOutputStream(schemaOutputStream)) { - schemaObjOut.writeObject(SchemaTranslation.schemaToProto(configRow.getSchema(), true)); - schemaObjOut.flush(); - transformBuilder.putAnnotations(CONFIG_ROW_SCHEMA_KEY, schemaOutputStream.toByteString()); - } + transformBuilder.putAnnotations( + CONFIG_ROW_KEY, + ByteString.copyFrom( + CoderUtils.encodeToByteArray(RowCoder.of(configRow.getSchema()), configRow))); + + transformBuilder.putAnnotations( + CONFIG_ROW_SCHEMA_KEY, + ByteString.copyFrom( + SchemaTranslation.schemaToProto(configRow.getSchema(), true).toByteArray())); } return transformBuilder.build(); diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java index 8abb606df8e7..9a9373932eb2 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java @@ -17,9 +17,7 @@ */ package org.apache.beam.runners.core.construction; -import java.io.ByteArrayInputStream; import java.io.IOException; -import java.io.ObjectInputStream; import java.net.ServerSocket; import java.util.ArrayList; import java.util.Collection; @@ -28,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; import org.apache.beam.model.expansion.v1.ExpansionApi; import org.apache.beam.model.pipeline.v1.Endpoints; @@ -35,13 +34,9 @@ import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; import org.apache.beam.model.pipeline.v1.SchemaApi; -import org.apache.beam.sdk.coders.RowCoder; -import org.apache.beam.sdk.schemas.SchemaTranslation; import org.apache.beam.sdk.transformservice.launcher.TransformServiceLauncher; -import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannelBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; @@ -84,7 +79,7 @@ static TransformUpgrader of(ExpansionServiceClientFactory clientFactory) { */ public RunnerApi.Pipeline upgradeTransformsViaTransformService( RunnerApi.Pipeline pipeline, List urnsToOverride, ExternalTranslationOptions options) - throws Exception { + throws IOException, TimeoutException { List transformsToOverride = pipeline.getComponents().getTransformsMap().entrySet().stream() .filter( @@ -103,43 +98,35 @@ public RunnerApi.Pipeline upgradeTransformsViaTransformService( String serviceAddress; TransformServiceLauncher service = null; - try { - if (options.getTransformServiceAddress() != null) { - serviceAddress = options.getTransformServiceAddress(); - } else if (options.getTransformServiceBeamVersion() != null) { - String projectName = UUID.randomUUID().toString(); - int port = findAvailablePort(); - service = TransformServiceLauncher.forProject(projectName, port); - service.setBeamVersion(options.getTransformServiceBeamVersion()); - - // Starting the transform service. - service.start(); - service.waitTillUp(40); - serviceAddress = "localhost:" + Integer.toString(port); - System.out.println("Done waiting ..."); - } else { - throw new IllegalArgumentException( - "Either option TransformServiceAddress or option TransformServiceBeamVersion should be " - + "provided to override a transform using the transform service"); - } - Endpoints.ApiServiceDescriptor expansionServiceEndpoint = - Endpoints.ApiServiceDescriptor.newBuilder().setUrl(serviceAddress).build(); + if (options.getTransformServiceAddress() != null) { + serviceAddress = options.getTransformServiceAddress(); + } else if (options.getTransformServiceBeamVersion() != null) { + String projectName = UUID.randomUUID().toString(); + int port = findAvailablePort(); + service = TransformServiceLauncher.forProject(projectName, port); + service.setBeamVersion(options.getTransformServiceBeamVersion()); + + // Starting the transform service. + service.start(); + service.waitTillUp(-1); + serviceAddress = "localhost:" + Integer.toString(port); + } else { + throw new IllegalArgumentException( + "Either option TransformServiceAddress or option TransformServiceBeamVersion should be " + + "provided to override a transform using the transform service"); + } - for (String transformId : transformsToOverride) { - pipeline = - updateTransformViaTransformService(pipeline, transformId, expansionServiceEndpoint); - } - } catch (Exception e) { - throw new RuntimeException(e); + Endpoints.ApiServiceDescriptor expansionServiceEndpoint = + Endpoints.ApiServiceDescriptor.newBuilder().setUrl(serviceAddress).build(); + + for (String transformId : transformsToOverride) { + pipeline = + updateTransformViaTransformService(pipeline, transformId, expansionServiceEndpoint); } if (service != null) { - try { - service.shutdown(); - } catch (IOException e) { - throw new RuntimeException(e); - } + service.shutdown(); } return pipeline; @@ -153,41 +140,24 @@ RunnerApi.Pipeline updateTransformViaTransformService( RunnerApi.Pipeline runnerAPIpipeline, String transformId, Endpoints.ApiServiceDescriptor transformServiceEndpoint) - throws Exception { + throws IOException { PTransform transformToUpgrade = runnerAPIpipeline.getComponents().getTransformsMap().get(transformId); if (transformToUpgrade == null) { - throw new Exception("Could not find a transform with the ID " + transformId); + throw new IllegalArgumentException("Could not find a transform with the ID " + transformId); } ByteString configRowBytes = transformToUpgrade.getAnnotationsOrThrow(PTransformTranslation.CONFIG_ROW_KEY); ByteString configRowSchemaBytes = transformToUpgrade.getAnnotationsOrThrow(PTransformTranslation.CONFIG_ROW_SCHEMA_KEY); - SchemaApi.Schema configRowSchemaProto; - try { - configRowSchemaProto = - (SchemaApi.Schema) - new ObjectInputStream(new ByteArrayInputStream(configRowSchemaBytes.toByteArray())) - .readObject(); - } catch (ClassNotFoundException e) { - throw new RuntimeException(e); - } - - Row configRow = - RowCoder.of(SchemaTranslation.schemaFromProto(configRowSchemaProto)) - .decode(new ByteArrayInputStream(configRowBytes.toByteArray())); - ByteStringOutputStream outputStream = new ByteStringOutputStream(); - try { - RowCoder.of(configRow.getSchema()).encode(configRow, outputStream); - } catch (IOException e) { - throw new RuntimeException(e); - } + SchemaApi.Schema configRowSchemaProto = + SchemaApi.Schema.parseFrom(configRowSchemaBytes.toByteArray()); ExternalTransforms.ExternalConfigurationPayload payload = ExternalTransforms.ExternalConfigurationPayload.newBuilder() .setSchema(configRowSchemaProto) - .setPayload(outputStream.toByteString()) + .setPayload(configRowBytes) .build(); RunnerApi.PTransform.Builder ptransformBuilder = @@ -219,7 +189,7 @@ RunnerApi.Pipeline updateTransformViaTransformService( clientFactory.getExpansionServiceClient(transformServiceEndpoint).expand(request); if (!Strings.isNullOrEmpty(response.getError())) { - throw new IOException(String.format("expansion service error: %s", response.getError())); + throw new RuntimeException(String.format("expansion service error: %s", response.getError())); } Map newEnvironmentsWithDependencies = @@ -243,8 +213,8 @@ RunnerApi.Pipeline updateTransformViaTransformService( RunnerApi.Components.Builder newComponentsBuilder = expandedComponents.toBuilder(); - // Some transforms may refer to already overridden transform as one of their input. We record - // such occurrences and correct them by referring to the upgraded transform instead. + // We record transforms that consume outputs of the old transform and update them to consume + // outputs of the new (upgraded) transform. Collection oldOutputs = transformToUpgrade.getOutputsMap().values(); Map inputReplacements = new HashMap<>(); if (transformToUpgrade.getOutputsMap().size() == 1) { @@ -254,14 +224,14 @@ RunnerApi.Pipeline updateTransformViaTransformService( } else { for (Map.Entry entry : transformToUpgrade.getOutputsMap().entrySet()) { if (expandedTransform.getOutputsMap().keySet().contains(entry.getKey())) { - throw new Exception( + throw new IllegalArgumentException( "Original transform did not have an output with tag " + entry.getKey() + " but upgraded transform did."); } String newOutput = expandedTransform.getOutputsMap().get(entry.getKey()); if (newOutput == null) { - throw new Exception( + throw new IllegalArgumentException( "Could not find an output with tag " + entry.getKey() + " for the transform " @@ -271,8 +241,6 @@ RunnerApi.Pipeline updateTransformViaTransformService( } } - String newTransformId = transformId + "_upgraded"; - // The list of obsolete (overridden) transforms that should be removed from the pipeline // produced by this method. List transformsToRemove = new ArrayList<>(); @@ -305,42 +273,18 @@ RunnerApi.Pipeline updateTransformViaTransformService( transformBuilder.clearInputs(); transformBuilder.putAllInputs(updatedInputsMap); } - - // Fix sub-transforms - if (entry.getValue().getSubtransformsList().contains(transformId)) { - List updatedSubTransforms = - entry.getValue().getSubtransformsList().stream() - .map( - subtransformId -> { - return subtransformId.equals(transformId) - ? newTransformId - : subtransformId; - }) - .collect(Collectors.toList()); - transformBuilder.clearSubtransforms(); - transformBuilder.addAllSubtransforms(updatedSubTransforms); - } - return transformBuilder.build(); })); newComponentsBuilder.clearTransforms(); newComponentsBuilder.putAllTransforms(updatedExpandedTransformMap); - newComponentsBuilder.putTransforms(newTransformId, expandedTransform); - - // We fix the root in case the overridden transform was one of the roots. - List rootTransformIds = - runnerAPIpipeline.getRootTransformIdsList().stream() - .map(id -> id.equals(transformId) ? newTransformId : id) - .collect(Collectors.toList()); + newComponentsBuilder.putTransforms(transformId, expandedTransform); RunnerApi.Pipeline.Builder newRunnerAPIPipelineBuilder = runnerAPIpipeline.toBuilder(); newRunnerAPIPipelineBuilder.clearComponents(); newRunnerAPIPipelineBuilder.setComponents(newComponentsBuilder.build()); newRunnerAPIPipelineBuilder.addAllRequirements(expandedRequirements); - newRunnerAPIPipelineBuilder.clearRootTransformIds(); - newRunnerAPIPipelineBuilder.addAllRootTransformIds(rootTransformIds); return newRunnerAPIPipelineBuilder.build(); } diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformUpgraderTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformUpgraderTest.java index abf6ccd81585..6620e780bc16 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformUpgraderTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformUpgraderTest.java @@ -24,6 +24,7 @@ import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.nio.charset.Charset; import java.util.Collections; import java.util.List; import java.util.Map; @@ -44,6 +45,7 @@ import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; import org.junit.Test; @@ -166,13 +168,25 @@ public ExpansionApi.ExpansionResponse expand(ExpansionApi.ExpansionRequest reque RunnerApi.Components.Builder responseComponents = request.getComponents().toBuilder(); RunnerApi.PTransform transformToUpgrade = request.getComponents().getTransformsMap().get("TransformUpgraderTest-TestTransform"); - if (transformToUpgrade == null) { + ByteString alreadyUpgraded = ByteString.empty(); + try { + alreadyUpgraded = transformToUpgrade.getAnnotationsOrThrow("already_upgraded"); + } catch (Exception e) { + // Ignore + } + if (!alreadyUpgraded.isEmpty()) { transformToUpgrade = request .getComponents() .getTransformsMap() .get("TransformUpgraderTest-TestTransform2"); } + if (!transformToUpgrade + .getSpec() + .getUrn() + .equals(request.getTransform().getSpec().getUrn())) { + throw new RuntimeException("Could not find a valid transform to upgrade"); + } Integer oldParam; try { @@ -184,7 +198,6 @@ public ExpansionApi.ExpansionResponse expand(ExpansionApi.ExpansionRequest reque throw new RuntimeException(e); } - System.out.println("Oldparam: " + oldParam); RunnerApi.PTransform.Builder upgradedTransform = transformToUpgrade.toBuilder(); FunctionSpec.Builder specBuilder = upgradedTransform.getSpecBuilder(); @@ -199,6 +212,9 @@ public ExpansionApi.ExpansionResponse expand(ExpansionApi.ExpansionRequest reque } upgradedTransform.setSpec(specBuilder.build()); + upgradedTransform.putAnnotations( + "already_upgraded", + ByteString.copyFrom("dummyvalue".getBytes(Charset.defaultCharset()))); response = ExpansionApi.ExpansionResponse.newBuilder() @@ -239,8 +255,6 @@ private void validateTestParam(RunnerApi.PTransform updatedTestTransform, Intege } assertEquals(Integer.valueOf(expectedValue), updatedParam); - - System.out.println("Updated param: " + updatedParam); } @Test @@ -267,7 +281,7 @@ public void testTransformUpgrade() throws Exception { upgradedPipelineProto .getComponents() .getTransformsMap() - .get("TransformUpgraderTest-TestTransform_upgraded"); + .get("TransformUpgraderTest-TestTransform"); validateTestParam(upgradedTransform, 4); } @@ -301,14 +315,14 @@ public void testTransformUpgradeMultipleOccurrences() throws Exception { upgradedPipelineProto .getComponents() .getTransformsMap() - .get("TransformUpgraderTest-TestTransform_upgraded"); + .get("TransformUpgraderTest-TestTransform"); validateTestParam(upgradedTransform1, 4); RunnerApi.PTransform upgradedTransform2 = upgradedPipelineProto .getComponents() .getTransformsMap() - .get("TransformUpgraderTest-TestTransform2_upgraded"); + .get("TransformUpgraderTest-TestTransform2"); validateTestParam(upgradedTransform2, 4); } @@ -342,14 +356,14 @@ public void testTransformUpgradeMultipleURNs() throws Exception { upgradedPipelineProto .getComponents() .getTransformsMap() - .get("TransformUpgraderTest-TestTransform_upgraded"); + .get("TransformUpgraderTest-TestTransform"); validateTestParam(upgradedTransform1, 4); RunnerApi.PTransform upgradedTransform2 = upgradedPipelineProto .getComponents() .getTransformsMap() - .get("TransformUpgraderTest-TestTransform2_upgraded"); + .get("TransformUpgraderTest-TestTransform2"); validateTestParam(upgradedTransform2, 4); } } From 3898c12292bcb66f4e4a23c70f667e2ba2edea2e Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Thu, 21 Sep 2023 21:54:11 -0700 Subject: [PATCH 7/8] Do not bundle Transform Service Launcher in the harness --- sdks/java/harness/build.gradle | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/java/harness/build.gradle b/sdks/java/harness/build.gradle index f157cbadee57..a2a99bc01e1f 100644 --- a/sdks/java/harness/build.gradle +++ b/sdks/java/harness/build.gradle @@ -29,6 +29,7 @@ dependencies { // :sdks:java:core and transitive dependencies provided project(path: ":model:pipeline", configuration: "shadow") provided project(path: ":sdks:java:core", configuration: "shadow") + provided project(path: ":sdks:java:transform-service:launcher", configuration: "shadow") provided library.java.joda_time provided library.java.slf4j_api provided library.java.vendored_grpc_1_54_0 From 55b74b593e565f919fd9bda788f0b659e8ff4bad Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Fri, 22 Sep 2023 16:53:56 -0700 Subject: [PATCH 8/8] Fix harness build and a fix for when a runner invokes toProto() multiple times --- .../runners/core/construction/TransformUpgrader.java | 12 ++++++++---- sdks/java/harness/build.gradle | 1 + 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java index 9a9373932eb2..d657bb31b184 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java @@ -51,16 +51,20 @@ public class TransformUpgrader implements AutoCloseable { private ExpansionServiceClientFactory clientFactory; - private static final ExpansionServiceClientFactory DEFAULT = - DefaultExpansionServiceClientFactory.create( - endPoint -> ManagedChannelBuilder.forTarget(endPoint.getUrl()).usePlaintext().build()); + private TransformUpgrader() { + // Creating a default 'ExpansionServiceClientFactory' instance per 'TransformUpgrader' instance + // so that each instance can maintain a set of live channels and close them independently. + clientFactory = + DefaultExpansionServiceClientFactory.create( + endPoint -> ManagedChannelBuilder.forTarget(endPoint.getUrl()).usePlaintext().build()); + } private TransformUpgrader(ExpansionServiceClientFactory clientFactory) { this.clientFactory = clientFactory; } public static TransformUpgrader of() { - return new TransformUpgrader(DEFAULT); + return new TransformUpgrader(); } @VisibleForTesting diff --git a/sdks/java/harness/build.gradle b/sdks/java/harness/build.gradle index a2a99bc01e1f..25d6b2ac4040 100644 --- a/sdks/java/harness/build.gradle +++ b/sdks/java/harness/build.gradle @@ -81,6 +81,7 @@ dependencies { implementation project(":runners:core-construction-java") implementation project(":runners:core-java") implementation project(":sdks:java:fn-execution") + permitUnusedDeclared project(path: ":sdks:java:transform-service:launcher") testImplementation library.java.junit testImplementation library.java.mockito_core shadowTestRuntimeClasspath project(path: ":sdks:java:core", configuration: "shadowTest")