From 81a50f2331b96277ffa9bb6359250c379d927d8b Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 11 Sep 2024 11:25:41 -0700 Subject: [PATCH 01/82] Add various utility meta-transforms to Beam. --- CHANGES.md | 2 + .../apache/beam/sdk/transforms/Flatten.java | 77 ++++++++++++++++ .../org/apache/beam/sdk/transforms/Tee.java | 91 +++++++++++++++++++ .../beam/sdk/transforms/FlattenTest.java | 27 ++++++ .../apache/beam/sdk/transforms/TeeTest.java | 84 +++++++++++++++++ sdks/python/apache_beam/transforms/core.py | 28 ++++++ .../apache_beam/transforms/ptransform_test.py | 12 +++ sdks/python/apache_beam/transforms/util.py | 34 +++++++ .../apache_beam/transforms/util_test.py | 29 ++++++ 9 files changed, 384 insertions(+) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java create mode 100644 sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java diff --git a/CHANGES.md b/CHANGES.md index d58ceffeb411..a9d6eeba10d8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,6 +69,8 @@ * Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)). * Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349)) * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Add new meta-transform FlattenWith and Tee that allow one to introduce branching + without breaking the linear/chaining style of pipeline construction. ## Breaking Changes diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java index 8da3cf71af9f..afc11353f1a5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java @@ -20,6 +20,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableLikeCoder; import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionList; @@ -82,6 +83,82 @@ public static Iterables iterables() { return new Iterables<>(); } + /** + * Returns a {@link PTransform} that flattens the input {@link PCollection} with a given a {@link + * PCollection} resulting in a {@link PCollection} containing all the elements of both {@link + * PCollection}s as its output. + * + *

This is equivalent to creating a {@link PCollectionList} containing both the input and + * {@code other} and then applying {@link #pCollections()}, but has the advantage that it can be + * more easily used inline. + * + *

Both {@code PCollections} must have equal {@link WindowFn}s. The output elements of {@code + * Flatten} are in the same windows and have the same timestamps as their corresponding input + * elements. The output {@code PCollection} will have the same {@link WindowFn} as both inputs. + * + * @param other the other PCollection to flatten with the input + * @param the type of the elements in the input and output {@code PCollection}s. + */ + public static PTransform, PCollection> with(PCollection other) { + return new FlattenWithPCollection<>(other); + } + + /** Implementation of {@link #with(PCollection)}. */ + private static class FlattenWithPCollection + extends PTransform, PCollection> { + // We only need to access this at pipeline construction time. + private final transient PCollection other; + + public FlattenWithPCollection(PCollection other) { + this.other = other; + } + + @Override + public PCollection expand(PCollection input) { + return PCollectionList.of(input).and(other).apply(pCollections()); + } + + @Override + public String getKindString() { + return "Flatten.With"; + } + } + + /** + * Returns a {@link PTransform} that flattens the input {@link PCollection} with the output of + * another {@link PTransform} resulting in a {@link PCollection} containing all the elements of + * both the input {@link PCollection}s and the output of the given {@link PTransform} as its + * output. + * + *

This is equivalent to creating a {@link PCollectionList} containing both the input and the + * output of {@code other} and then applying {@link #pCollections()}, but has the advantage that + * it can be more easily used inline. + * + *

Both {@cpde PCollections} must have equal {@link WindowFn}s. The output elements of {@code + * Flatten} are in the same windows and have the same timestamps as their corresponding input + * elements. The output {@code PCollection} will have the same {@link WindowFn} as both inputs. + * + * @param the type of the elements in the input and output {@code PCollection}s. + * @param other a PTransform whose ouptput should be flattened with the input + * @param the type of the elements in the input and output {@code PCollection}s. + */ + public static PTransform, PCollection> with( + PTransform> other) { + return new PTransform, PCollection>() { + @Override + public PCollection expand(PCollection input) { + return PCollectionList.of(input) + .and(input.getPipeline().apply(other)) + .apply(pCollections()); + } + + @Override + public String getKindString() { + return "Flatten.With"; + } + }; + } + /** * A {@link PTransform} that flattens a {@link PCollectionList} into a {@link PCollection} * containing all the elements of all the {@link PCollection}s in its input. Implements {@link diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java new file mode 100644 index 000000000000..bb65cbf94632 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import java.util.function.Consumer; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; + +/** + * A PTransform that returns its input, but also applies its input to an auxiliary PTransform, akin + * to the shell {@code tee} command. + * + *

This can be useful to write out or otherwise process an intermediate transform without + * breaking the linear flow of a chain of transforms, e.g. + * + *


+ * {@literal PCollection<T>} input = ... ;
+ * {@literal PCollection<T>} result =
+ *     {@literal input.apply(...)}
+ *     ...
+ *     {@literal input.apply(Tee.of(someSideTransform))}
+ *     ...
+ *     {@literal input.apply(...)};
+ * 
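+ * <p>As a hypothetical extension of the chain above (a sketch only; {@code writeDebugOutput} and
+ * {@code moreData} are placeholder transforms), {@code Tee} composes with {@code Flatten.with}
+ * to branch and merge without leaving the chaining style:
+ *
+ *     {@literal input.apply(Tee.of(writeDebugOutput))}
+ *     {@literal      .apply(Flatten.with(moreData))}
+ *     {@literal      .apply(...);}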
+ * + * @param the element type of the input PCollection + */ +public class Tee extends PTransform, PCollection> { + private final PTransform, ?> consumer; + + /** + * Returns a new Tee PTransform that will apply an auxilary transform to the input as well as pass + * it on. + * + * @param consumer An additional PTransform that should process the input PCollection. Its output + * will be ignored. + * @param the type of the elements in the input {@code PCollection}. + */ + public static Tee of(PTransform, ?> consumer) { + return new Tee<>(consumer); + } + + /** + * Returns a new Tee PTransform that will apply an auxilary transform to the input as well as pass + * it on. + * + * @param consumer An arbitrary {@link Consumer} that will be wrapped in a PTransform and applied + * to the input. Its output will be ignored. + * @param the type of the elements in the input {@code PCollection}. + */ + public static Tee of(Consumer> consumer) { + return of( + new PTransform, PCollectionTuple>() { + @Override + public PCollectionTuple expand(PCollection input) { + consumer.accept(input); + return PCollectionTuple.empty(input.getPipeline()); + } + }); + } + + private Tee(PTransform, ?> consumer) { + this.consumer = consumer; + } + + @Override + public PCollection expand(PCollection input) { + input.apply(consumer); + return input; + } + + @Override + protected String getKindString() { + return "Tee(" + consumer.getName() + ")"; + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java index 282a41bed0dc..7a02d95a5046 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java @@ -402,6 +402,32 @@ public void testFlattenWithDifferentInputAndOutputCoders2() { ///////////////////////////////////////////////////////////////////////////// + @Test + @Category(NeedsRunner.class) + public void testFlattenWithPCollection() { + PCollection output = + p.apply(Create.of(LINES)) + .apply("FlattenWithLines1", Flatten.with(p.apply("Create1", Create.of(LINES)))) + .apply("FlattenWithLines2", Flatten.with(p.apply("Create2", Create.of(LINES2)))); + + PAssert.that(output).containsInAnyOrder(flattenLists(Arrays.asList(LINES, LINES2, LINES))); + p.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testFlattenWithPTransform() { + PCollection output = + p.apply(Create.of(LINES)) + .apply("Create1", Flatten.with(Create.of(LINES))) + .apply("Create2", Flatten.with(Create.of(LINES2))); + + PAssert.that(output).containsInAnyOrder(flattenLists(Arrays.asList(LINES, LINES2, LINES))); + p.run(); + } + + ///////////////////////////////////////////////////////////////////////////// + @Test @Category(NeedsRunner.class) public void testEqualWindowFnPropagation() { @@ -470,6 +496,7 @@ public void testIncompatibleWindowFnPropagationFailure() { public void testFlattenGetName() { Assert.assertEquals("Flatten.Iterables", Flatten.iterables().getName()); Assert.assertEquals("Flatten.PCollections", Flatten.pCollections().getName()); + Assert.assertEquals("Flatten.With", Flatten.with((PCollection) null).getName()); } ///////////////////////////////////////////////////////////////////////////// diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java new file mode 100644 index 000000000000..ee3a00c46caa 
--- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.UUID; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashMultimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimaps; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for Tee. */ +@RunWith(JUnit4.class) +public class TeeTest { + + @Rule public final transient TestPipeline p = TestPipeline.create(); + + @Test + @Category(NeedsRunner.class) + public void testTee() { + List elements = Arrays.asList("a", "b", "c"); + CollectToMemory collector = new CollectToMemory<>(); + PCollection output = p.apply(Create.of(elements)).apply(Tee.of(collector)); + + PAssert.that(output).containsInAnyOrder(elements); + p.run().waitUntilFinish(); + + // Here we assert that this "sink" had the correct side effects. 
+ assertThat(collector.get(), containsInAnyOrder(elements.toArray(new String[3]))); + } + + private static class CollectToMemory extends PTransform, PCollection> { + + private static final Multimap ALL_ELEMENTS = + Multimaps.synchronizedMultimap(HashMultimap.create()); + + UUID uuid = UUID.randomUUID(); + + @Override + public PCollection expand(PCollection input) { + return input.apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + ALL_ELEMENTS.put(uuid, c.element()); + } + })); + } + + @SuppressWarnings("unchecked") + public Collection get() { + return (Collection) ALL_ELEMENTS.get(uuid); + } + } +} diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index d7415e8d8135..5671779e5811 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -103,6 +103,7 @@ 'Windowing', 'WindowInto', 'Flatten', + 'FlattenWith', 'Create', 'Impulse', 'RestrictionProvider', @@ -3836,6 +3837,33 @@ def from_runner_api_parameter( common_urns.primitives.FLATTEN.urn, None, Flatten.from_runner_api_parameter) +class FlattenWith(PTransform): + """A PTransform that flattens its input with other PCollections. + + This is equivalent to creating a tuple containing both the input and the + other PCollection(s), but has the advantage that it can be more easily used + inline. + + Root PTransforms can be passed as well as PCollections, in which case their + outputs will be flattened. + """ + def __init__(self, *others): + self._others = others + + def expand(self, pcoll): + pcolls = [pcoll] + for other in self._others: + if isinstance(other, pvalue.PCollection): + pcolls.append(other) + elif isinstance(other, PTransform): + pcolls.append(pcoll.pipeline | other) + else: + raise TypeError( + 'FlattenWith only takes other PCollections and PTransforms, ' + f'got {other}') + return tuple(pcolls) | Flatten() + + class Create(PTransform): """A transform that creates a PCollection from an iterable.""" def __init__(self, values, reshuffle=True): diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index a51d5cd83d26..2c9037185286 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -787,6 +787,18 @@ def split_even_odd(element): assert_that(even_length, equal_to(['AA', 'CC']), label='assert:even') assert_that(odd_length, equal_to(['BBB']), label='assert:odd') + def test_flatten_with(self): + with TestPipeline() as pipeline: + input = pipeline | 'Start' >> beam.Create(['AA', 'BBB', 'CC']) + + result = ( + input + | 'WithPCollection' >> beam.FlattenWith(input | beam.Map(str.lower)) + | 'WithPTransform' >> beam.FlattenWith(beam.Create(['x', 'y']))) + + assert_that( + result, equal_to(['AA', 'BBB', 'CC', 'aa', 'bbb', 'cc', 'x', 'y'])) + def test_group_by_key_input_must_be_kv_pairs(self): with self.assertRaises(typehints.TypeCheckError) as e: with TestPipeline() as pipeline: diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index a27c7aca9e20..a03652de2496 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -30,6 +30,7 @@ import uuid from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import List from typing import Tuple @@ -44,6 +45,7 @@ from apache_beam.portability 
import common_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.pvalue import AsSideInput +from apache_beam.pvalue import PCollection from apache_beam.transforms import window from apache_beam.transforms.combiners import CountCombineFn from apache_beam.transforms.core import CombinePerKey @@ -92,6 +94,7 @@ 'RemoveDuplicates', 'Reshuffle', 'ToString', + 'Tee', 'Values', 'WithKeys', 'GroupIntoBatches' @@ -1665,6 +1668,37 @@ def _process(element): return pcoll | FlatMap(_process) +@typehints.with_input_types(T) +@typehints.with_output_types(T) +class Tee(PTransform): + """A PTransform that returns its input, but also applies its input elsewhere. + + Similar to the shell {@code tee} command. This can be useful to write out or + otherwise process an intermediate transform without breaking the linear flow + of a chain of transforms, e.g.:: + + (input + | SomePTransform() + | ... + | Tee(SomeSideTransform()) + | ...) + """ + def __init__( + self, + *consumers: Union[PTransform[PCollection[T], Any], + Callable[[PCollection[T]], Any]]): + self._consumers = consumers + + def expand(self, input): + for consumer in self._consumers: + print("apply", consumer) + if callable(consumer): + _ = input | ptransform_fn(consumer)() + else: + _ = input | consumer + return input + + @typehints.with_input_types(T) @typehints.with_output_types(T) class WaitOn(PTransform): diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 74d9f438a5df..9c131504e6f4 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -19,6 +19,8 @@ # pytype: skip-file +import collections +import importlib import logging import math import random @@ -27,6 +29,7 @@ import unittest import warnings from datetime import datetime +from typing import Mapping import pytest import pytz @@ -1812,6 +1815,32 @@ def test_split_without_empty(self): assert_that(result, equal_to(expected_result)) +class TeeTest(unittest.TestCase): + _side_effects: Mapping[str, int] = collections.defaultdict(int) + + def test_tee(self): + # The imports here are to avoid issues with the class (and its attributes) + # possibly being pickled rather than referenced. + def cause_side_effect(element): + importlib.import_module(__name__).TeeTest._side_effects[element] += 1 + + def count_side_effects(element): + return importlib.import_module(__name__).TeeTest._side_effects[element] + + with TestPipeline() as p: + result = ( + p + | beam.Create(['a', 'b', 'c']) + | 'TeePTransform' >> beam.Tee(beam.Map(cause_side_effect)) + | 'TeeCallable' >> beam.Tee( + lambda pcoll: pcoll | beam.Map( + lambda element: cause_side_effect('X' + element)))) + assert_that(result, equal_to(['a', 'b', 'c'])) + + self.assertEqual(count_side_effects('a'), 1) + self.assertEqual(count_side_effects('Xa'), 1) + + class WaitOnTest(unittest.TestCase): def test_find(self): # We need shared reference that survives pickling. From 71d97b6896a7de7fdbb48f0f7835081e411964fc Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 17 Sep 2024 15:11:09 -0700 Subject: [PATCH 02/82] Add note about FlattenWith to the documentation. 
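For context, the chaining style that these transforms enable in the Python SDK looks roughly like the following minimal sketch, based on the tests added above. `beam.FlattenWith` and `beam.Tee` are the transforms introduced in the previous commit; `beam.Map(print)` is just a placeholder side consumer, and the sketch assumes a Beam build that already contains these changes.

```python
import apache_beam as beam

with beam.Pipeline() as p:
    result = (
        p
        | beam.Create(['a', 'b', 'c'])
        # Merge in another root transform's output without breaking the chain.
        | beam.FlattenWith(beam.Create(['d', 'e']))
        # Branch off a side consumer; its output is ignored, so 'result' still
        # holds the five flattened elements.
        | beam.Tee(beam.Map(print))
        | beam.Map(str.upper))
```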
--- .../apache_beam/examples/snippets/snippets.py | 33 +++++++++++++++++++ .../examples/snippets/snippets_test.py | 6 ++++ .../en/documentation/programming-guide.md | 27 ++++++++++++++- 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/examples/snippets/snippets.py b/sdks/python/apache_beam/examples/snippets/snippets.py index 715011d302d2..2636f7d2637d 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets.py +++ b/sdks/python/apache_beam/examples/snippets/snippets.py @@ -1143,6 +1143,39 @@ def model_multiple_pcollections_flatten(contents, output_path): merged | beam.io.WriteToText(output_path) +def model_multiple_pcollections_flatten_with(contents, output_path): + """Merging a PCollection with FlattenWith.""" + some_hash_fn = lambda s: ord(s[0]) + partition_fn = lambda element, partitions: some_hash_fn(element) % partitions + import apache_beam as beam + with TestPipeline() as pipeline: # Use TestPipeline for testing. + + # Partition into deciles + partitioned = pipeline | beam.Create(contents) | beam.Partition( + partition_fn, 3) + pcoll1 = partitioned[0] + pcoll2 = partitioned[1] + pcoll3 = partitioned[2] + SomeTransform = lambda: beam.Map(lambda x: x) + SomeOtherTransform = lambda: beam.Map(lambda x: x) + + # Flatten them back into 1 + + # A collection of PCollection objects can be represented simply + # as a tuple (or list) of PCollections. + # (The SDK for Python has no separate type to store multiple + # PCollection objects, whether containing the same or different + # types.) + # [START model_multiple_pcollections_flatten_with] + merged = ( + pcoll1 + | SomeTransform() + | beam.FlattenWith(pcoll2, pcoll3) + | SomeOtherTransform()) + # [END model_multiple_pcollections_flatten_with] + merged | beam.io.WriteToText(output_path) + + def model_multiple_pcollections_partition(contents, output_path): """Splitting a PCollection with Partition.""" some_hash_fn = lambda s: ord(s[0]) diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py index e8cb8960cf4d..0560e9710f03 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets_test.py +++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py @@ -917,6 +917,12 @@ def test_model_multiple_pcollections_flatten(self): snippets.model_multiple_pcollections_flatten(contents, result_path) self.assertEqual(contents, self.get_output(result_path)) + def test_model_multiple_pcollections_flatten_with(self): + contents = ['a', 'b', 'c', 'd', 'e', 'f'] + result_path = self.create_temp_file() + snippets.model_multiple_pcollections_flatten_with(contents, result_path) + self.assertEqual(contents, self.get_output(result_path)) + def test_model_multiple_pcollections_partition(self): contents = [17, 42, 64, 32, 0, 99, 53, 89] result_path = self.create_temp_file() diff --git a/website/www/site/content/en/documentation/programming-guide.md b/website/www/site/content/en/documentation/programming-guide.md index c716c7554db4..cdf82d566a4f 100644 --- a/website/www/site/content/en/documentation/programming-guide.md +++ b/website/www/site/content/en/documentation/programming-guide.md @@ -2024,7 +2024,7 @@ playerAccuracies := ... // PCollection #### 4.2.5. 
Flatten {#flatten} [`Flatten`](https://beam.apache.org/releases/javadoc/{{< param release_latest >}}/index.html?org/apache/beam/sdk/transforms/Flatten.html) -[`Flatten`](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/transforms/core.py) +[`Flatten`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.core.html#apache_beam.transforms.core.Flatten) [`Flatten`](https://github.com/apache/beam/blob/master/sdks/go/pkg/beam/flatten.go) `Flatten` is a Beam transform for `PCollection` objects that store the same data type. @@ -2045,6 +2045,22 @@ PCollectionList collections = PCollectionList.of(pc1).and(pc2).and(pc3); PCollection merged = collections.apply(Flatten.pCollections()); {{< /highlight >}} +{{< paragraph class="language-java" >}} +One can also use the [`FlattenWith`](https://beam.apache.org/releases/javadoc/{{< param release_latest >}}/index.html?org/apache/beam/sdk/transforms/Flatten.html) +transform to merge PCollections into an output PCollection in a manner more compatible with chaining. +{{< /paragraph >}} + +{{< highlight java >}} +PCollection merged = pc1 + .apply(...) + // Merges the elements of pc2 in at this point... + .apply(FlattenWith.of(pc2)) + .apply(...) + // and the elements of pc3 at this point. + .apply(FlattenWith.of(pc3)) + .apply(...); +{{< /highlight >}} + {{< highlight py >}} # Flatten takes a tuple of PCollection objects. @@ -2052,6 +2068,15 @@ PCollection merged = collections.apply(Flatten.pCollections()); {{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten >}} {{< /highlight >}} +{{< paragraph class="language-py" >}} +One can also use the [`FlattenWith`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.core.html#apache_beam.transforms.core.FlattenWith) +transform to merge PCollections into an output PCollection in a manner more compatible with chaining. +{{< /paragraph >}} + +{{< highlight py >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten_with >}} +{{< /highlight >}} + {{< highlight go >}} // Flatten accepts any number of PCollections of the same element type. // Returns a single PCollection that contains all of the elements in input PCollections. From efcf1bab3d3cb3b9f295d5db22fab16897ef7f88 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Mon, 23 Sep 2024 10:23:11 -0700 Subject: [PATCH 03/82] Fix checkstyle rule. --- .../src/main/java/org/apache/beam/sdk/transforms/Flatten.java | 1 - 1 file changed, 1 deletion(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java index afc11353f1a5..6d785c3bc591 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java @@ -140,7 +140,6 @@ public String getKindString() { * * @param the type of the elements in the input and output {@code PCollection}s. * @param other a PTransform whose ouptput should be flattened with the input - * @param the type of the elements in the input and output {@code PCollection}s. 
*/ public static PTransform, PCollection> with( PTransform> other) { From 54576c3fbf570ecca986e34b81306f3a492fb2a7 Mon Sep 17 00:00:00 2001 From: Reeba Qureshi <64488642+reeba212@users.noreply.github.com> Date: Thu, 22 Aug 2024 23:01:38 +0530 Subject: [PATCH 04/82] Create yaml_enrichment.py --- .../apache_beam/yaml/yaml_enrichment.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 sdks/python/apache_beam/yaml/yaml_enrichment.py diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment.py b/sdks/python/apache_beam/yaml/yaml_enrichment.py new file mode 100644 index 000000000000..1d73e8c1794e --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_enrichment.py @@ -0,0 +1,55 @@ +from typing import Any, Dict +import apache_beam as beam +from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler +from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler +from apache_beam.transforms.enrichment_handlers.feast_feature_store import FeastFeatureStoreEnrichmentHandler +from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import VertexAIFeatureStoreEnrichmentHandler +from apache_beam.transforms.enrichment import Enrichment +from typing import Optional + +@beam.ptransform.ptransform_fn +def enrichment_transform(pcoll, enrichment_handler: str, handler_config: Dict[str, Any], timeout: Optional[float] = 30): + """ + The Enrichment transform allows you to dynamically enhance elements in a pipeline + by performing key-value lookups against external services like APIs or databases. + + Args: + enrichment_handler: Specifies the source from where data needs to be extracted + into the pipeline for enriching data. It can be a string value in ["BigQuery", + "BigTable", "FeastFeatureStore", "VertexAIFeatureStore"]. + handler_config: Specifies the parameters for the respective enrichment_handler in a dictionary format. 
+ BigQuery: project, table_name, row_restriction_template, fields, column_names, condition_value_fn, query_fn, min_batch_size, max_batch_size + BigTable: project_id, instance_id, table_id row_key, row_filter, app_profile_id, encoding, ow_key_fn, exception_level, include_timestamp + FeastFeatureStore: feature_store_yaml_path, feature_names, feature_service_name, full_feature_names, entity_row_fn, exception_level + VertexAIFeatureStore: project, location, api_endpoint, feature_store_name:, feature_view_name, row_key, exception_level + + Example Usage: + + - type: Enrichment + config: + enrichment_handler: 'BigTable' + handler_config: + project_id: 'apache-beam-testing' + instance_id: 'beam-test' + table_id: 'bigtable-enrichment-test' + row_key: 'product_id' + timeout: 30 + + """ + if enrichment_handler is None: + raise ValueError("Missing 'source' in enrichment spec.") + if handler_config is None: + raise ValueError("Missing 'handler_config' in enrichment spec.") + + handler_map = { + 'BigQuery': BigQueryEnrichmentHandler, + 'BigTable': BigTableEnrichmentHandler, + 'FeastFeatureStore': FeastFeatureStoreEnrichmentHandler, + 'VertexAIFeatureStore': VertexAIFeatureStoreEnrichmentHandler + } + + if enrichment_handler not in handler_map: + raise ValueError(f"Unknown enrichment source: {enrichment_handler}") + + handler = handler_map[enrichment_handler](**handler_config) + return pcoll | Enrichment(source_handler = handler, timeout = timeout) From 02a22bc2901ebec8c82455ed53e224b5a674c4df Mon Sep 17 00:00:00 2001 From: Reeba Qureshi <64488642+reeba212@users.noreply.github.com> Date: Thu, 22 Aug 2024 23:02:26 +0530 Subject: [PATCH 05/82] Create yaml_enrichment_test.py --- .../apache_beam/yaml/yaml_enrichment_test.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 sdks/python/apache_beam/yaml/yaml_enrichment_test.py diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py new file mode 100644 index 000000000000..4042e33c3520 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py @@ -0,0 +1,62 @@ +import unittest +import logging +import mock +import apache_beam as beam +from apache_beam.testing.util import assert_that, equal_to +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.transforms import Map +from apache_beam.yaml.yaml_enrichment import enrichment_transform +from apache_beam import Row +from unittest.mock import patch +from apache_beam.yaml.yaml_transform import YamlTransform + +class FakeEnrichmentTransform: + def __init__(self, enrichment_handler, handler_config, timeout = 30): + self._enrichment_handler = enrichment_handler + self._handler_config = handler_config + self._timeout = timeout + + def __call__(self, enrichment_handler, *, handler_config, timeout = 30): + assert enrichment_handler == self._enrichment_handler + assert handler_config == self._handler_config + assert timeout == self._timeout + return beam.Map(lambda x: beam.Row(**x._asdict())) + + +class EnrichmentTransformTest(unittest.TestCase): + + @patch('apache_beam.yaml.yaml_enrichment.enrichment_transform', FakeEnrichmentTransform) + def test_enrichment_with_bigquery(self): + input_data = [ + Row(label = "item1", rank = 0), + Row(label = "item2", rank = 1), + ] + + handler = 'BigQuery' + config = { + "project": "apache-beam-testing", + "table_name": "project.database.table", + "row_restriction_template": "label='item1' or label='item2'", + "fields": ["label"] + } + + with beam.Pipeline() as 
p: + with mock.patch('apache_beam.yaml.yaml_enrichment.enrichment_transform', + FakeEnrichmentTransform( + enrichment_handler = handler, + handler_config = config)): + input_pcoll = p | 'CreateInput' >> beam.Create(input_data) + result = input_pcoll | YamlTransform( + f''' + type: Enrichment + config: + enrichment_handler: {handler} + handler_config: {config} + ''') + assert_that( + result, + equal_to(input_data)) + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() From f9fc865fbe4a06df44c461cabb8c841d279668b7 Mon Sep 17 00:00:00 2001 From: Reeba Qureshi <64488642+reeba212@users.noreply.github.com> Date: Thu, 22 Aug 2024 23:03:41 +0530 Subject: [PATCH 06/82] Create enrichment integration test --- .../apache_beam/yaml/tests/enrichment.yaml | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 sdks/python/apache_beam/yaml/tests/enrichment.yaml diff --git a/sdks/python/apache_beam/yaml/tests/enrichment.yaml b/sdks/python/apache_beam/yaml/tests/enrichment.yaml new file mode 100644 index 000000000000..7c107b2e71c1 --- /dev/null +++ b/sdks/python/apache_beam/yaml/tests/enrichment.yaml @@ -0,0 +1,66 @@ +fixtures: + - name: BQ_TABLE + type: "apache_beam.yaml.integration_tests.temp_bigquery_table" + config: + project: "apache-beam-testing" + - name: TEMP_DIR + type: "apache_beam.yaml.integration_tests.gcs_temp_dir" + config: + bucket: "gs://temp-storage-for-end-to-end-tests/temp-it" + +pipelines: + - pipeline: + type: chain + transforms: + - type: Create + name: Rows + config: + elements: + - {label: '11a', rank: 0} + - {label: '37a', rank: 1} + - {label: '389a', rank: 2} + + - type: WriteToBigQuery + config: + table: "{BQ_TABLE}" + + - pipeline: + type: chain + transforms: + - type: Create + name: Data + config: + elements: + - {label: '11a', name: 'S1'} + - {label: '37a', name: 'S2'} + - {label: '389a', name: 'S3'} + - type: Enrichment + name: Enriched + config: + enrichment_handler: 'BigQuery' + handler_config: + project: apache-beam-testing + table_name: "{BQ_TABLE}" + fields: ['label'] + row_restriction_template: "label = '37a'" + timeout: 30 + + - type: MapToFields + config: + language: python + fields: + label: + callable: 'lambda x: x.label' + output_type: string + rank: + callable: 'lambda x: x.rank' + output_type: integer + name: + callable: 'lambda x: x.name' + output_type: string + + - type: AssertEqual + config: + elements: + - {label: '37a', rank: 1, name: 'S2'} + From ba2b10e6b32a339dc2a29b19cb0f79ececde9121 Mon Sep 17 00:00:00 2001 From: Reeba Qureshi <64488642+reeba212@users.noreply.github.com> Date: Thu, 22 Aug 2024 23:07:06 +0530 Subject: [PATCH 07/82] Register enrichment transform in standard_providers.yaml --- sdks/python/apache_beam/yaml/standard_providers.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index 574179805959..242faaa9a77b 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -101,3 +101,8 @@ Explode: "beam:schematransform:org.apache.beam:yaml:explode:v1" config: gradle_target: 'sdks:java:extensions:sql:expansion-service:shadowJar' + +- type: 'python' + config: {} + transforms: + Enrichment: 'apache_beam.yaml.yaml_enrichment.enrichment_transform' From 788e4a20e95efd2ab9ce487f2844b0f800ac3694 Mon Sep 17 00:00:00 2001 From: Reeba Qureshi Date: Mon, 26 Aug 2024 16:45:57 +0530 Subject: [PATCH 08/82] minor changes 
1. Added links for different handlers and removed code for unreachable conditions 2. Removed patch decorator in test --- sdks/python/apache_beam/yaml/yaml_enrichment.py | 13 ++++--------- .../python/apache_beam/yaml/yaml_enrichment_test.py | 1 - 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment.py b/sdks/python/apache_beam/yaml/yaml_enrichment.py index 1d73e8c1794e..a3d5a4f8c942 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment.py @@ -18,10 +18,10 @@ def enrichment_transform(pcoll, enrichment_handler: str, handler_config: Dict[st into the pipeline for enriching data. It can be a string value in ["BigQuery", "BigTable", "FeastFeatureStore", "VertexAIFeatureStore"]. handler_config: Specifies the parameters for the respective enrichment_handler in a dictionary format. - BigQuery: project, table_name, row_restriction_template, fields, column_names, condition_value_fn, query_fn, min_batch_size, max_batch_size - BigTable: project_id, instance_id, table_id row_key, row_filter, app_profile_id, encoding, ow_key_fn, exception_level, include_timestamp - FeastFeatureStore: feature_store_yaml_path, feature_names, feature_service_name, full_feature_names, entity_row_fn, exception_level - VertexAIFeatureStore: project, location, api_endpoint, feature_store_name:, feature_view_name, row_key, exception_level + BigQuery : project, table_name, row_restriction_template, fields, column_names, condition_value_fn, query_fn, min_batch_size, max_batch_size + BigTable : project_id, instance_id, table_id row_key, row_filter, app_profile_id, encoding, ow_key_fn, exception_level, include_timestamp + FeastFeatureStore : feature_store_yaml_path, feature_names, feature_service_name, full_feature_names, entity_row_fn, exception_level + VertexAIFeatureStore : project, location, api_endpoint, feature_store_name:, feature_view_name, row_key, exception_level Example Usage: @@ -36,11 +36,6 @@ def enrichment_transform(pcoll, enrichment_handler: str, handler_config: Dict[st timeout: 30 """ - if enrichment_handler is None: - raise ValueError("Missing 'source' in enrichment spec.") - if handler_config is None: - raise ValueError("Missing 'handler_config' in enrichment spec.") - handler_map = { 'BigQuery': BigQueryEnrichmentHandler, 'BigTable': BigTableEnrichmentHandler, diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py index 4042e33c3520..35e333fc2f41 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py @@ -25,7 +25,6 @@ def __call__(self, enrichment_handler, *, handler_config, timeout = 30): class EnrichmentTransformTest(unittest.TestCase): - @patch('apache_beam.yaml.yaml_enrichment.enrichment_transform', FakeEnrichmentTransform) def test_enrichment_with_bigquery(self): input_data = [ Row(label = "item1", rank = 0), From e20166d87c0369ee7cb243bea59d5bc84daa3f22 Mon Sep 17 00:00:00 2001 From: Reeba Qureshi Date: Thu, 19 Sep 2024 21:48:32 +0530 Subject: [PATCH 09/82] minor updates --- .../apache_beam/yaml/integration_tests.py | 2 +- .../apache_beam/yaml/tests/enrichment.yaml | 17 ++++ .../apache_beam/yaml/yaml_enrichment.py | 97 ++++++++++++++----- .../apache_beam/yaml/yaml_enrichment_test.py | 94 ++++++++++-------- 4 files changed, 144 insertions(+), 66 deletions(-) diff --git a/sdks/python/apache_beam/yaml/integration_tests.py 
b/sdks/python/apache_beam/yaml/integration_tests.py index af1be7b1e8e5..72b3918195da 100644 --- a/sdks/python/apache_beam/yaml/integration_tests.py +++ b/sdks/python/apache_beam/yaml/integration_tests.py @@ -69,7 +69,7 @@ def temp_bigquery_table(project, prefix='yaml_bq_it_'): dataset_id = '%s_%s' % (prefix, uuid.uuid4().hex) bigquery_client.get_or_create_dataset(project, dataset_id) logging.info("Created dataset %s in project %s", dataset_id, project) - yield f'{project}:{dataset_id}.tmp_table' + yield f'{project}.{dataset_id}.tmp_table' request = bigquery.BigqueryDatasetsDeleteRequest( projectId=project, datasetId=dataset_id, deleteContents=True) logging.info("Deleting dataset %s in project %s", dataset_id, project) diff --git a/sdks/python/apache_beam/yaml/tests/enrichment.yaml b/sdks/python/apache_beam/yaml/tests/enrichment.yaml index 7c107b2e71c1..216a18add83f 100644 --- a/sdks/python/apache_beam/yaml/tests/enrichment.yaml +++ b/sdks/python/apache_beam/yaml/tests/enrichment.yaml @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + fixtures: - name: BQ_TABLE type: "apache_beam.yaml.integration_tests.temp_bigquery_table" diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment.py b/sdks/python/apache_beam/yaml/yaml_enrichment.py index a3d5a4f8c942..77428bbd59f5 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment.py @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + from typing import Any, Dict import apache_beam as beam from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler @@ -7,21 +24,53 @@ from apache_beam.transforms.enrichment import Enrichment from typing import Optional + @beam.ptransform.ptransform_fn -def enrichment_transform(pcoll, enrichment_handler: str, handler_config: Dict[str, Any], timeout: Optional[float] = 30): - """ - The Enrichment transform allows you to dynamically enhance elements in a pipeline - by performing key-value lookups against external services like APIs or databases. 
+def enrichment_transform( + pcoll, + enrichment_handler: str, + handler_config: Dict[str, Any], + timeout: Optional[float] = 30): + """ + The Enrichment transform allows you to dynamically + enhance elements in a pipeline by performing key-value + lookups against external services like APIs or databases. Args: - enrichment_handler: Specifies the source from where data needs to be extracted - into the pipeline for enriching data. It can be a string value in ["BigQuery", - "BigTable", "FeastFeatureStore", "VertexAIFeatureStore"]. - handler_config: Specifies the parameters for the respective enrichment_handler in a dictionary format. - BigQuery : project, table_name, row_restriction_template, fields, column_names, condition_value_fn, query_fn, min_batch_size, max_batch_size - BigTable : project_id, instance_id, table_id row_key, row_filter, app_profile_id, encoding, ow_key_fn, exception_level, include_timestamp - FeastFeatureStore : feature_store_yaml_path, feature_names, feature_service_name, full_feature_names, entity_row_fn, exception_level - VertexAIFeatureStore : project, location, api_endpoint, feature_store_name:, feature_view_name, row_key, exception_level + enrichment_handler: Specifies the source from + where data needs to be extracted + into the pipeline for enriching data. + It can be a string value in ["BigQuery", + "BigTable", "FeastFeatureStore", + "VertexAIFeatureStore"]. + handler_config: Specifies the parameters for + the respective enrichment_handler in a dictionary format. + BigQuery = ( + "BigQuery: " + "project, table_name, row_restriction_template, " + "fields, column_names, "condition_value_fn, " + "query_fn, min_batch_size, max_batch_size" + ) + + BigTable = ( + "BigTable: " + "project_id, instance_id, table_id, " + "row_key, row_filter, app_profile_id, " + "encoding, ow_key_fn, exception_level, include_timestamp" + ) + + FeastFeatureStore = ( + "FeastFeatureStore: " + "feature_store_yaml_path, feature_names, " + "feature_service_name, full_feature_names, " + "entity_row_fn, exception_level" + ) + + VertexAIFeatureStore = ( + "VertexAIFeatureStore: " + "project, location, api_endpoint, feature_store_name, " + "feature_view_name, row_key, exception_level" + ) Example Usage: @@ -36,15 +85,15 @@ def enrichment_transform(pcoll, enrichment_handler: str, handler_config: Dict[st timeout: 30 """ - handler_map = { - 'BigQuery': BigQueryEnrichmentHandler, - 'BigTable': BigTableEnrichmentHandler, - 'FeastFeatureStore': FeastFeatureStoreEnrichmentHandler, - 'VertexAIFeatureStore': VertexAIFeatureStoreEnrichmentHandler - } - - if enrichment_handler not in handler_map: - raise ValueError(f"Unknown enrichment source: {enrichment_handler}") - - handler = handler_map[enrichment_handler](**handler_config) - return pcoll | Enrichment(source_handler = handler, timeout = timeout) + handler_map = { + 'BigQuery': BigQueryEnrichmentHandler, + 'BigTable': BigTableEnrichmentHandler, + 'FeastFeatureStore': FeastFeatureStoreEnrichmentHandler, + 'VertexAIFeatureStore': VertexAIFeatureStoreEnrichmentHandler + } + + if enrichment_handler not in handler_map: + raise ValueError(f"Unknown enrichment source: {enrichment_handler}") + + handler = handler_map[enrichment_handler](**handler_config) + return pcoll | Enrichment(source_handler=handler, timeout=timeout) diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py index 35e333fc2f41..9cd28995dfe4 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py +++ 
b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py @@ -1,61 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import unittest import logging import mock import apache_beam as beam -from apache_beam.testing.util import assert_that, equal_to -from apache_beam.testing.test_pipeline import TestPipeline -from apache_beam.transforms import Map -from apache_beam.yaml.yaml_enrichment import enrichment_transform +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to from apache_beam import Row -from unittest.mock import patch from apache_beam.yaml.yaml_transform import YamlTransform + class FakeEnrichmentTransform: - def __init__(self, enrichment_handler, handler_config, timeout = 30): - self._enrichment_handler = enrichment_handler - self._handler_config = handler_config - self._timeout = timeout + def __init__(self, enrichment_handler, handler_config, timeout=30): + self._enrichment_handler = enrichment_handler + self._handler_config = handler_config + self._timeout = timeout - def __call__(self, enrichment_handler, *, handler_config, timeout = 30): - assert enrichment_handler == self._enrichment_handler - assert handler_config == self._handler_config - assert timeout == self._timeout - return beam.Map(lambda x: beam.Row(**x._asdict())) + def __call__(self, enrichment_handler, *, handler_config, timeout=30): + assert enrichment_handler == self._enrichment_handler + assert handler_config == self._handler_config + assert timeout == self._timeout + return beam.Map(lambda x: beam.Row(**x._asdict())) class EnrichmentTransformTest(unittest.TestCase): + def test_enrichment_with_bigquery(self): + input_data = [ + Row(label="item1", rank=0), + Row(label="item2", rank=1), + ] - def test_enrichment_with_bigquery(self): - input_data = [ - Row(label = "item1", rank = 0), - Row(label = "item2", rank = 1), - ] - - handler = 'BigQuery' - config = { - "project": "apache-beam-testing", - "table_name": "project.database.table", - "row_restriction_template": "label='item1' or label='item2'", - "fields": ["label"] - } - - with beam.Pipeline() as p: - with mock.patch('apache_beam.yaml.yaml_enrichment.enrichment_transform', - FakeEnrichmentTransform( - enrichment_handler = handler, - handler_config = config)): - input_pcoll = p | 'CreateInput' >> beam.Create(input_data) - result = input_pcoll | YamlTransform( - f''' + handler = 'BigQuery' + config = { + "project": "apache-beam-testing", + "table_name": "project.database.table", + "row_restriction_template": "label='item1' or label='item2'", + "fields": ["label"] + } + + with beam.Pipeline() as p: + with mock.patch('apache_beam.yaml.yaml_enrichment.enrichment_transform', + FakeEnrichmentTransform(enrichment_handler=handler, + handler_config=config)): + input_pcoll = p | 'CreateInput' 
>> beam.Create(input_data) + result = input_pcoll | YamlTransform( + f''' type: Enrichment config: enrichment_handler: {handler} handler_config: {config} ''') - assert_that( - result, - equal_to(input_data)) + assert_that(result, equal_to(input_data)) + if __name__ == '__main__': - logging.getLogger().setLevel(logging.INFO) - unittest.main() + logging.getLogger().setLevel(logging.INFO) + unittest.main() From ee6258ef5abfbf110067497975c509033e547a97 Mon Sep 17 00:00:00 2001 From: Reeba Qureshi Date: Fri, 20 Sep 2024 00:36:57 +0530 Subject: [PATCH 10/82] fixing lint failures --- sdks/python/apache_beam/yaml/yaml_enrichment.py | 8 +++++--- sdks/python/apache_beam/yaml/yaml_enrichment_test.py | 6 ++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment.py b/sdks/python/apache_beam/yaml/yaml_enrichment.py index 77428bbd59f5..4ec8a5a786d3 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment.py @@ -15,14 +15,16 @@ # limitations under the License. # -from typing import Any, Dict +from typing import Any +from typing import Dict +from typing import Optional + import apache_beam as beam +from apache_beam.transforms.enrichment import Enrichment from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler from apache_beam.transforms.enrichment_handlers.feast_feature_store import FeastFeatureStoreEnrichmentHandler from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import VertexAIFeatureStoreEnrichmentHandler -from apache_beam.transforms.enrichment import Enrichment -from typing import Optional @beam.ptransform.ptransform_fn diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py index 9cd28995dfe4..e26d6140af23 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py @@ -15,13 +15,15 @@ # limitations under the License. 
# -import unittest import logging +import unittest + import mock + import apache_beam as beam +from apache_beam import Row from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -from apache_beam import Row from apache_beam.yaml.yaml_transform import YamlTransform From c647e47c573e884196c4919027ad59e4575af429 Mon Sep 17 00:00:00 2001 From: Reeba Qureshi Date: Fri, 20 Sep 2024 07:39:31 +0530 Subject: [PATCH 11/82] disable feast if not installed --- sdks/python/apache_beam/yaml/yaml_enrichment.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment.py b/sdks/python/apache_beam/yaml/yaml_enrichment.py index 4ec8a5a786d3..0fbe57321395 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment.py @@ -20,12 +20,17 @@ from typing import Optional import apache_beam as beam +from apache_beam.yaml import options from apache_beam.transforms.enrichment import Enrichment from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler -from apache_beam.transforms.enrichment_handlers.feast_feature_store import FeastFeatureStoreEnrichmentHandler from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import VertexAIFeatureStoreEnrichmentHandler +try: + from apache_beam.transforms.enrichment_handlers.feast_feature_store import FeastFeatureStoreEnrichmentHandler +except ImportError: + FeastFeatureStoreEnrichmentHandler = None + @beam.ptransform.ptransform_fn def enrichment_transform( @@ -87,6 +92,12 @@ def enrichment_transform( timeout: 30 """ + options.YamlOptions.check_enabled(pcoll.pipeline, 'Enrichment') + if (enrichment_handler == 'FeastFeatureStore' and + not FeastFeatureStoreEnrichmentHandler): + raise ValueError( + "FeastFeatureStore handler requires 'feast' package to be installed. 
" + + "Please install using 'pip install feast[gcp]' and try again.") handler_map = { 'BigQuery': BigQueryEnrichmentHandler, 'BigTable': BigTableEnrichmentHandler, From e917933c09005d6d37d4ad9c0116a6688322258b Mon Sep 17 00:00:00 2001 From: Reeba Qureshi Date: Thu, 3 Oct 2024 23:36:34 +0530 Subject: [PATCH 12/82] fix failures --- .../apache_beam/yaml/standard_providers.yaml | 3 ++- .../apache_beam/yaml/tests/enrichment.yaml | 3 ++- .../apache_beam/yaml/yaml_enrichment.py | 21 ++++++++++++++----- sdks/python/apache_beam/yaml/yaml_provider.py | 16 ++++++++++---- 4 files changed, 32 insertions(+), 11 deletions(-) diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index 242faaa9a77b..15d5fdc24914 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -103,6 +103,7 @@ gradle_target: 'sdks:java:extensions:sql:expansion-service:shadowJar' - type: 'python' - config: {} + config: + requires_gcp: true transforms: Enrichment: 'apache_beam.yaml.yaml_enrichment.enrichment_transform' diff --git a/sdks/python/apache_beam/yaml/tests/enrichment.yaml b/sdks/python/apache_beam/yaml/tests/enrichment.yaml index 216a18add83f..6469c094b8b4 100644 --- a/sdks/python/apache_beam/yaml/tests/enrichment.yaml +++ b/sdks/python/apache_beam/yaml/tests/enrichment.yaml @@ -80,4 +80,5 @@ pipelines: config: elements: - {label: '37a', rank: 1, name: 'S2'} - + options: + yaml_experimental_features: [ 'Enrichment' ] \ No newline at end of file diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment.py b/sdks/python/apache_beam/yaml/yaml_enrichment.py index 0fbe57321395..e2dc72f3dac8 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment.py @@ -21,15 +21,19 @@ import apache_beam as beam from apache_beam.yaml import options -from apache_beam.transforms.enrichment import Enrichment -from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler -from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler -from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import VertexAIFeatureStoreEnrichmentHandler try: + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import VertexAIFeatureStoreEnrichmentHandler from apache_beam.transforms.enrichment_handlers.feast_feature_store import FeastFeatureStoreEnrichmentHandler except ImportError: - FeastFeatureStoreEnrichmentHandler = None + Enrichment = None # type: ignore + BigQueryEnrichmentHandler = None # type: ignore + BigTableEnrichmentHandler = None # type: ignore + VertexAIFeatureStoreEnrichmentHandler = None # type: ignore + FeastFeatureStoreEnrichmentHandler = None # type: ignore @beam.ptransform.ptransform_fn @@ -93,11 +97,18 @@ def enrichment_transform( """ options.YamlOptions.check_enabled(pcoll.pipeline, 'Enrichment') + + if not Enrichment: + raise ValueError( + f"gcp dependencies not installed. Cannot use {enrichment_handler} " + f"handler. 
Please install using 'pip install apache-beam[gcp]'.") + if (enrichment_handler == 'FeastFeatureStore' and not FeastFeatureStoreEnrichmentHandler): raise ValueError( "FeastFeatureStore handler requires 'feast' package to be installed. " + "Please install using 'pip install feast[gcp]' and try again.") + handler_map = { 'BigQuery': BigQueryEnrichmentHandler, 'BigTable': BigTableEnrichmentHandler, diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index c2cba936abce..2d6ed2e5b956 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -231,9 +231,17 @@ def provider_from_spec(cls, spec): result.to_json = lambda: spec return result except Exception as exn: - raise ValueError( - f'Unable to instantiate provider of type {type} ' - f'at line {SafeLineLoader.get_line(spec)}: {exn}') from exn + if isinstance(exn, ModuleNotFoundError) and config.get('requires_gcp', + False): + print( + f"gcp dependencies not installed. Cannot use transforms: " + f"{', '.join(urns.keys())}. Please install using " + f"'pip install apache-beam[gcp]'.") + return InlineProvider({}) + else: + raise ValueError( + f'Unable to instantiate provider of type {type} ' + f'at line {SafeLineLoader.get_line(spec)}: {exn}') from exn else: raise NotImplementedError( f'Unknown provider type: {type} ' @@ -335,7 +343,7 @@ def cache_artifacts(self): @ExternalProvider.register_provider_type('python') -def python(urns, packages=()): +def python(urns, packages=(), requires_gcp=False): if packages: return ExternalPythonProvider(urns, packages) else: From e24ea06263f20d3425673e96e245d0527ec83743 Mon Sep 17 00:00:00 2001 From: Reeba Qureshi Date: Fri, 4 Oct 2024 00:31:41 +0530 Subject: [PATCH 13/82] separate block for feast import error --- sdks/python/apache_beam/yaml/yaml_enrichment.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment.py b/sdks/python/apache_beam/yaml/yaml_enrichment.py index e2dc72f3dac8..00f2a5c1b1d1 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment.py @@ -27,12 +27,15 @@ from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import VertexAIFeatureStoreEnrichmentHandler - from apache_beam.transforms.enrichment_handlers.feast_feature_store import FeastFeatureStoreEnrichmentHandler except ImportError: Enrichment = None # type: ignore BigQueryEnrichmentHandler = None # type: ignore BigTableEnrichmentHandler = None # type: ignore VertexAIFeatureStoreEnrichmentHandler = None # type: ignore + +try: + from apache_beam.transforms.enrichment_handlers.feast_feature_store import FeastFeatureStoreEnrichmentHandler +except ImportError: FeastFeatureStoreEnrichmentHandler = None # type: ignore From 19044ec97729f1328577ac4930d79a517e713741 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 3 Oct 2024 15:46:21 -0400 Subject: [PATCH 14/82] Touch Flink trigger files for testing addition of Flink 1.19 support --- .../setup-default-test-properties/test-properties.json | 2 +- .github/trigger_files/beam_PostCommit_Go_VR_Flink.json | 3 ++- .../beam_PostCommit_Java_Examples_Flink.json | 3 +++ .../beam_PostCommit_Java_Jpms_Flink_Java11.json | 3 +++ .../beam_PostCommit_Java_ValidatesRunner_Flink.json | 3 ++- 
...beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json | 3 ++- .../beam_PostCommit_Java_ValidatesRunner_Flink_Java8.json | 4 ++++ .../beam_PostCommit_Python_ValidatesRunner_Flink.json | 3 +++ .github/trigger_files/beam_PostCommit_XVR_Flink.json | 3 +++ sdks/go/examples/wasm/README.md | 2 +- sdks/python/apache_beam/options/pipeline_options.py | 2 +- sdks/typescript/src/apache_beam/runners/flink.ts | 2 +- .../www/site/content/en/documentation/runners/flink.md | 8 +++++++- 13 files changed, 33 insertions(+), 8 deletions(-) create mode 100644 .github/trigger_files/beam_PostCommit_Java_Examples_Flink.json create mode 100644 .github/trigger_files/beam_PostCommit_Java_Jpms_Flink_Java11.json create mode 100644 .github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.json create mode 100644 .github/trigger_files/beam_PostCommit_XVR_Flink.json diff --git a/.github/actions/setup-default-test-properties/test-properties.json b/.github/actions/setup-default-test-properties/test-properties.json index 098e4ca1935c..efe66de8ee1e 100644 --- a/.github/actions/setup-default-test-properties/test-properties.json +++ b/.github/actions/setup-default-test-properties/test-properties.json @@ -14,7 +14,7 @@ }, "JavaTestProperties": { "SUPPORTED_VERSIONS": ["8", "11", "17", "21"], - "FLINK_VERSIONS": ["1.15", "1.16", "1.17", "1.18"], + "FLINK_VERSIONS": ["1.15", "1.16", "1.17", "1.18", "1.19"], "SPARK_VERSIONS": ["2", "3"] }, "GoTestProperties": { diff --git a/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json b/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json index e3d6056a5de9..b98aece75634 100644 --- a/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 1, + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" } diff --git a/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json b/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json new file mode 100644 index 000000000000..dd9afb90e638 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_Java_Jpms_Flink_Java11.json b/.github/trigger_files/beam_PostCommit_Java_Jpms_Flink_Java11.json new file mode 100644 index 000000000000..dd9afb90e638 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_Jpms_Flink_Java11.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json index b970762c8397..9200c368abbe 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json 
b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json index b970762c8397..9200c368abbe 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.json new file mode 100644 index 000000000000..b07a3c47e196 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.json @@ -0,0 +1,4 @@ +{ + + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json index e69de29bb2d1..0b34d452d42c 100644 --- a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_XVR_Flink.json b/.github/trigger_files/beam_PostCommit_XVR_Flink.json new file mode 100644 index 000000000000..0b34d452d42c --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_XVR_Flink.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" +} diff --git a/sdks/go/examples/wasm/README.md b/sdks/go/examples/wasm/README.md index 84d30a3c6a63..a78649134305 100644 --- a/sdks/go/examples/wasm/README.md +++ b/sdks/go/examples/wasm/README.md @@ -68,7 +68,7 @@ cd $BEAM_HOME Expected output should include the following, from which you acquire the latest flink runner version. ```shell -'flink_versions: 1.15,1.16,1.17,1.18' +'flink_versions: 1.15,1.16,1.17,1.18,1.19' ``` #### 2. Set to the latest flink runner version i.e. 1.16 diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 4497ab0993a4..837dc0f5439f 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1679,7 +1679,7 @@ def _add_argparse_args(cls, parser): class FlinkRunnerOptions(PipelineOptions): # These should stay in sync with gradle.properties. - PUBLISHED_FLINK_VERSIONS = ['1.15', '1.16', '1.17', '1.18'] + PUBLISHED_FLINK_VERSIONS = ['1.15', '1.16', '1.17', '1.18', '1.19'] @classmethod def _add_argparse_args(cls, parser): diff --git a/sdks/typescript/src/apache_beam/runners/flink.ts b/sdks/typescript/src/apache_beam/runners/flink.ts index ad4339b431f5..e21876c0d517 100644 --- a/sdks/typescript/src/apache_beam/runners/flink.ts +++ b/sdks/typescript/src/apache_beam/runners/flink.ts @@ -28,7 +28,7 @@ import { JavaJarService } from "../utils/service"; const MAGIC_HOST_NAMES = ["[local]", "[auto]"]; // These should stay in sync with gradle.properties. 
-const PUBLISHED_FLINK_VERSIONS = ["1.15", "1.16", "1.17", "1.18"]; +const PUBLISHED_FLINK_VERSIONS = ["1.15", "1.16", "1.17", "1.18", "1.19"]; const defaultOptions = { flinkMaster: "[local]", diff --git a/website/www/site/content/en/documentation/runners/flink.md b/website/www/site/content/en/documentation/runners/flink.md index 7325c480955c..2c28aa7062ec 100644 --- a/website/www/site/content/en/documentation/runners/flink.md +++ b/website/www/site/content/en/documentation/runners/flink.md @@ -93,7 +93,7 @@ from the [compatibility table](#flink-version-compatibility) below. For example: {{< highlight java >}} org.apache.beam - beam-runners-flink-1.17 + beam-runners-flink-1.19 {{< param release_latest >}} {{< /highlight >}} @@ -200,6 +200,7 @@ Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are availab [Flink 1.16](https://hub.docker.com/r/apache/beam_flink1.16_job_server). [Flink 1.17](https://hub.docker.com/r/apache/beam_flink1.17_job_server). [Flink 1.18](https://hub.docker.com/r/apache/beam_flink1.18_job_server). +[Flink 1.19](https://hub.docker.com/r/apache/beam_flink1.19_job_server). {{< /paragraph >}} @@ -326,6 +327,11 @@ To find out which version of Flink is compatible with Beam please see the table Artifact Id Supported Beam Versions + + 1.19.x + beam-runners-flink-1.19 + ≥ 2.61.0 + 1.18.x beam-runners-flink-1.18 From c288f1ad38410b266a8bf6ed5b54dc2d55935de3 Mon Sep 17 00:00:00 2001 From: Reeba Qureshi Date: Fri, 4 Oct 2024 20:01:56 +0530 Subject: [PATCH 15/82] minor changes --- .../apache_beam/yaml/standard_providers.yaml | 3 +-- sdks/python/apache_beam/yaml/yaml_provider.py | 16 ++++------------ 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index 15d5fdc24914..242faaa9a77b 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -103,7 +103,6 @@ gradle_target: 'sdks:java:extensions:sql:expansion-service:shadowJar' - type: 'python' - config: - requires_gcp: true + config: {} transforms: Enrichment: 'apache_beam.yaml.yaml_enrichment.enrichment_transform' diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index 2d6ed2e5b956..c2cba936abce 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -231,17 +231,9 @@ def provider_from_spec(cls, spec): result.to_json = lambda: spec return result except Exception as exn: - if isinstance(exn, ModuleNotFoundError) and config.get('requires_gcp', - False): - print( - f"gcp dependencies not installed. Cannot use transforms: " - f"{', '.join(urns.keys())}. 
Please install using " - f"'pip install apache-beam[gcp]'.") - return InlineProvider({}) - else: - raise ValueError( - f'Unable to instantiate provider of type {type} ' - f'at line {SafeLineLoader.get_line(spec)}: {exn}') from exn + raise ValueError( + f'Unable to instantiate provider of type {type} ' + f'at line {SafeLineLoader.get_line(spec)}: {exn}') from exn else: raise NotImplementedError( f'Unknown provider type: {type} ' @@ -343,7 +335,7 @@ def cache_artifacts(self): @ExternalProvider.register_provider_type('python') -def python(urns, packages=(), requires_gcp=False): +def python(urns, packages=()): if packages: return ExternalPythonProvider(urns, packages) else: From b94f8a74b5a9c2a5216320342f038b03879dd289 Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Mon, 14 Oct 2024 11:03:51 -0400 Subject: [PATCH 16/82] Pin protobuf version for TF dependency tests (#32719) * Pin protobuf version for TF dependency tests * pin specfic protobuf version * increment tested TF version * pin at higher version of protobuf * fix incorrect configuration * try bumping tf version * further specify dependency versions * try adding other ml testing deps for compat * fix transformers tests * tweak deps for transformers * whitespace --- sdks/python/tox.ini | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index 8cc125d946a8..8cdc4a98bbfe 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -394,10 +394,12 @@ commands = [testenv:py39-tensorflow-212] deps = - 212: tensorflow>=2.12rc1,<2.13 - # Help pip resolve conflict with typing-extensions for old version of TF https://github.com/apache/beam/issues/30852 - 212: pydantic<2.7 -extras = test,gcp + 212: + tensorflow>=2.12rc1,<2.13 + # Help pip resolve conflict with typing-extensions for old version of TF https://github.com/apache/beam/issues/30852 + pydantic<2.7 + protobuf==4.25.5 +extras = test,gcp,ml_test commands = # Log tensorflow version for debugging /bin/sh -c "pip freeze | grep -E tensorflow" @@ -428,7 +430,8 @@ deps = 430: transformers>=4.30.0,<4.31.0 torch>=1.9.0,<1.14.0 tensorflow==2.12.0 -extras = test,gcp + protobuf==4.25.5 +extras = test,gcp,ml_test commands = # Log transformers and its dependencies version for debugging /bin/sh -c "pip freeze | grep -E transformers" From 1bc848b4b1ebfe4f3576b219d2d667f9a5cf8cbe Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Mon, 14 Oct 2024 11:35:18 -0400 Subject: [PATCH 17/82] Usse ubuntu-22.04 for release candidate build for now (#32767) --- .github/workflows/build_release_candidate.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build_release_candidate.yml b/.github/workflows/build_release_candidate.yml index fbb0ca22f333..a70091726c1b 100644 --- a/.github/workflows/build_release_candidate.yml +++ b/.github/workflows/build_release_candidate.yml @@ -97,7 +97,7 @@ jobs: stage_java_source: if: ${{ fromJson(github.event.inputs.STAGE).java_source == 'yes'}} - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Mask Apache Password run: | @@ -161,7 +161,7 @@ jobs: stage_python_artifacts: if: ${{ fromJson(github.event.inputs.STAGE).python_artifacts == 'yes'}} - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout uses: actions/checkout@v4 @@ -246,7 +246,7 @@ jobs: stage_docker: if: ${{ fromJson(github.event.inputs.STAGE).docker_artifacts == 'yes'}} # Note: if this ever changes to self-hosted, remove the "Remove 
default github maven configuration" step - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout uses: actions/checkout@v4 @@ -282,7 +282,7 @@ jobs: beam_site_pr: if: ${{ fromJson(github.event.inputs.STAGE).beam_site_pr == 'yes'}} # Note: if this ever changes to self-hosted, remove the "Remove default github maven configuration" step - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 env: RC_TAG: "v${{ github.event.inputs.RELEASE }}-RC${{ github.event.inputs.RC }}" BRANCH_NAME: updates_release_${{ github.event.inputs.RELEASE }} @@ -402,7 +402,7 @@ jobs: build_and_stage_prism: if: ${{ fromJson(github.event.inputs.STAGE).prism == 'yes'}} - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout uses: actions/checkout@v4 From 1f575d4d816032e136b44b15ea940a19c67a9466 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Mon, 14 Oct 2024 11:38:19 -0400 Subject: [PATCH 18/82] Avoid repeated run of setDefaultPipelineOptionsOnce in TestPipelineOptions.create (#32723) --- .../org/apache/beam/sdk/io/FileSystems.java | 23 +++++++++++++++++-- .../apache/beam/sdk/testing/TestPipeline.java | 4 ++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java index fb25cac6262f..5ca22749b163 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java @@ -565,11 +565,13 @@ static FileSystem getFileSystemInternal(String scheme) { * *

It will be used in {@link FileSystemRegistrar FileSystemRegistrars} for all schemes. * - *

This is expected only to be used by runners after {@code Pipeline.run}, or in tests. + *

Outside of workers where the Beam FileSystem API is used (e.g. test methods, user code executed
+   * during pipeline submission), consider using {@link #registerFileSystemsOnce} if initializing
+   * FileSystems for the supported schemes is the main goal.
    */
   @Internal
   public static void setDefaultPipelineOptions(PipelineOptions options) {
-    checkNotNull(options, "options");
+    checkNotNull(options, "options cannot be null");
 
     long id = options.getOptionsId();
     int nextRevision = options.revision();
@@ -593,6 +595,23 @@ public static void setDefaultPipelineOptions(PipelineOptions options) {
     }
   }
 
+  /**
+   * Register file systems once if never done before.
+   *
+   * 

This method executes {@link #setDefaultPipelineOptions} only if it has never been run, + * otherwise it returns immediately. + * + *

It is internally used by test setup to avoid repeated filesystem registrations (involves + * expensive ServiceLoader calls) when there are multiple pipeline and PipelineOptions object + * initialized, which is commonly seen in test execution. + */ + @Internal + public static synchronized void registerFileSystemsOnce(PipelineOptions options) { + if (FILESYSTEM_REVISION.get() == null) { + setDefaultPipelineOptions(options); + } + } + @VisibleForTesting static Map verifySchemesAreUnique( PipelineOptions options, Set registrars) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java index ed61f7f3d6f2..328bf19c466c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java @@ -373,7 +373,7 @@ public PipelineResult runWithAdditionalOptionArgs(List additionalArgs) { } newOptions.setStableUniqueNames(CheckEnabled.ERROR); - FileSystems.setDefaultPipelineOptions(options); + FileSystems.registerFileSystemsOnce(options); return run(newOptions); } catch (IOException e) { throw new RuntimeException( @@ -515,7 +515,7 @@ public static PipelineOptions testingPipelineOptions() { } options.setStableUniqueNames(CheckEnabled.ERROR); - FileSystems.setDefaultPipelineOptions(options); + FileSystems.registerFileSystemsOnce(options); return options; } catch (IOException e) { throw new RuntimeException( From c04e91d89d5ad84b20043e3f3fbfe8d5edfac6e9 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Mon, 14 Oct 2024 15:43:19 -0400 Subject: [PATCH 19/82] Fix download artifact truncate page (#32772) --- .../src/main/scripts/download_github_actions_artifacts.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/release/src/main/scripts/download_github_actions_artifacts.py b/release/src/main/scripts/download_github_actions_artifacts.py index 99526f1ac7d1..5c553efc6500 100644 --- a/release/src/main/scripts/download_github_actions_artifacts.py +++ b/release/src/main/scripts/download_github_actions_artifacts.py @@ -279,8 +279,11 @@ def fetch_github_artifacts(run_id, repo_url, artifacts_dir, github_token, rc_num print("Starting downloading artifacts ... 
(it may take a while)") run_data = get_single_workflow_run_data(run_id, repo_url, github_token) artifacts_url = safe_get(run_data, "artifacts_url") - data_artifacts = request_url(artifacts_url, github_token) + data_artifacts = request_url(artifacts_url + '?per_page=100', github_token) artifacts = safe_get(data_artifacts, "artifacts", artifacts_url) + total_count = safe_get(data_artifacts, "total_count", artifacts_url) + if int(total_count) != len(artifacts): + raise RuntimeError(f"Expected total count {total_count} different than returned list length {len(data_artifacts)}") print('Filtering ', len(artifacts), ' artifacts') filtered_artifacts = filter_artifacts(artifacts, rc_number) print('Preparing to download ', len(filtered_artifacts), ' artifacts') From f3708e06e130dc40baafc96f5e8786dc697ddb3c Mon Sep 17 00:00:00 2001 From: liferoad Date: Mon, 14 Oct 2024 17:06:03 -0400 Subject: [PATCH 20/82] Update pyproject.toml --- sdks/python/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index 4eb827297019..d34afcf70a5e 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -21,7 +21,7 @@ requires = [ "setuptools", "wheel>=0.36.0", - "grpcio-tools==1.62.1", + "grpcio-tools>=1.62.1", "mypy-protobuf==3.5.0", # Avoid https://github.com/pypa/virtualenv/issues/2006 "distlib==0.3.7", From e47c261502d722140b6bcb0e275bed4cb7b5624f Mon Sep 17 00:00:00 2001 From: liferoad Date: Mon, 14 Oct 2024 19:23:26 -0400 Subject: [PATCH 21/82] use 1.65.5 --- sdks/python/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index d34afcf70a5e..a99599a2ce2b 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -21,7 +21,7 @@ requires = [ "setuptools", "wheel>=0.36.0", - "grpcio-tools>=1.62.1", + "grpcio-tools==1.65.5", "mypy-protobuf==3.5.0", # Avoid https://github.com/pypa/virtualenv/issues/2006 "distlib==0.3.7", From 9a2d456f72b07ac449cd24c6195179a3b1b8be65 Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:10:31 +0300 Subject: [PATCH 22/82] [Python] Managed Transforms API (#31495) * managed module * clean up * lint * try with real example * cleanup * add documentation * fix doc * add pyyaml dependency * cleanup * return deps * return deps * fix doc * address some comments * doc updates * define managed transform URNs in proto * fix URN * remove managed dependency * add managed iceberg integration test * lint * lint * dependency fix * lint * dependency fix * dependency fix * lint * lint * dependency fix * rename test file --- ...m_PostCommit_Python_Xlang_IO_Dataflow.json | 2 +- ...eam_PostCommit_Python_Xlang_IO_Direct.json | 4 + ...beam_PostCommit_Python_Xlang_IO_Direct.yml | 96 +++++++++ CHANGES.md | 2 + .../pipeline/v1/external_transforms.proto | 14 ++ sdks/java/io/expansion-service/build.gradle | 2 + sdks/java/io/iceberg/build.gradle | 3 +- .../apache/beam/sdk/io/iceberg/IcebergIO.java | 5 +- .../IcebergReadSchemaTransformProvider.java | 6 +- .../IcebergWriteSchemaTransformProvider.java | 5 +- sdks/java/io/kafka/build.gradle | 1 + .../KafkaReadSchemaTransformProvider.java | 4 +- .../KafkaWriteSchemaTransformProvider.java | 5 +- sdks/java/managed/build.gradle | 1 + .../org/apache/beam/sdk/managed/Managed.java | 11 +- .../managed/ManagedTransformConstants.java | 13 +- .../apache_beam/portability/common_urns.py | 1 + 
.../python/apache_beam/transforms/__init__.py | 1 + sdks/python/apache_beam/transforms/managed.py | 182 ++++++++++++++++++ .../transforms/managed_iceberg_it_test.py | 70 +++++++ sdks/python/setup.py | 3 +- sdks/standard_expansion_services.yaml | 2 +- 22 files changed, 407 insertions(+), 26 deletions(-) create mode 100644 .github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json create mode 100644 .github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml create mode 100644 sdks/python/apache_beam/transforms/managed.py create mode 100644 sdks/python/apache_beam/transforms/managed_iceberg_it_test.py diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json index e3d6056a5de9..b26833333238 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json new file mode 100644 index 000000000000..e3d6056a5de9 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 +} diff --git a/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml new file mode 100644 index 000000000000..5092a1981154 --- /dev/null +++ b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: PostCommit Python Xlang IO Direct + +on: + schedule: + - cron: '30 5/6 * * *' + pull_request_target: + paths: ['release/trigger_all_tests.json', '.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json'] + workflow_dispatch: + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: write + checks: write + contents: read + deployments: read + id-token: none + issues: write + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PostCommit_Python_Xlang_IO_Direct: + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'pull_request_target' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + github.event.comment.body == 'Run Python_Xlang_IO_Direct PostCommit' + runs-on: [self-hosted, ubuntu-20.04, main] + timeout-minutes: 100 + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + strategy: + matrix: + job_name: ["beam_PostCommit_Python_Xlang_IO_Direct"] + job_phrase: ["Run Python_Xlang_IO_Direct PostCommit"] + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + python-version: | + 3.9 + 3.12 + - name: run PostCommit Python Xlang IO Direct script + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:python:test-suites:direct:ioCrossLanguagePostCommit + arguments: -PuseWheelDistribution + - name: Archive Python Test Results + uses: actions/upload-artifact@v4 + if: failure() + with: + name: Python Test Results + path: '**/pytest*.xml' + - name: Publish Python Test Results + uses: EnricoMi/publish-unit-test-result-action@v2 + if: always() + with: + commit: '${{ env.prsha || env.GITHUB_SHA }}' + comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} + files: '**/pytest*.xml' \ No newline at end of file diff --git a/CHANGES.md b/CHANGES.md index cc1268635046..4e21e400e60d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -59,11 +59,13 @@ * New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)). * New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)). +* [Python] Introduce Managed Transforms API ([#31495](https://github.com/apache/beam/pull/31495)) ## I/Os * Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). 
* [Managed Iceberg] Support creating tables if needed ([#32686](https://github.com/apache/beam/pull/32686)) +* [Managed Iceberg] Now available in Python SDK ([#31495](https://github.com/apache/beam/pull/31495)) * [Managed Iceberg] Add support for TIMESTAMP, TIME, and DATE types ([#32688](https://github.com/apache/beam/pull/32688)) ## New Features / Improvements diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto index 429371e11055..b03350966d6c 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto @@ -59,6 +59,20 @@ message ExpansionMethods { } } +// Defines the URNs for managed transforms. +message ManagedTransforms { + enum Urns { + ICEBERG_READ = 0 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:iceberg_read:v1"]; + ICEBERG_WRITE = 1 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:iceberg_write:v1"]; + KAFKA_READ = 2 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:kafka_read:v1"]; + KAFKA_WRITE = 3 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:kafka_write:v1"]; + } +} + // A configuration payload for an external transform. // Used to define a Java transform that can be directly instantiated by a Java // expansion service. diff --git a/sdks/java/io/expansion-service/build.gradle b/sdks/java/io/expansion-service/build.gradle index b09a92ca315c..cc8eccf98997 100644 --- a/sdks/java/io/expansion-service/build.gradle +++ b/sdks/java/io/expansion-service/build.gradle @@ -44,6 +44,8 @@ ext.summary = "Expansion service serving several Java IOs" dependencies { implementation project(":sdks:java:expansion-service") permitUnusedDeclared project(":sdks:java:expansion-service") // BEAM-11761 + implementation project(":sdks:java:managed") + permitUnusedDeclared project(":sdks:java:managed") // BEAM-11761 implementation project(":sdks:java:io:iceberg") permitUnusedDeclared project(":sdks:java:io:iceberg") // BEAM-11761 implementation project(":sdks:java:io:kafka") diff --git a/sdks/java/io/iceberg/build.gradle b/sdks/java/io/iceberg/build.gradle index 3d653d6b276e..e10c6f38e20f 100644 --- a/sdks/java/io/iceberg/build.gradle +++ b/sdks/java/io/iceberg/build.gradle @@ -44,7 +44,7 @@ def orc_version = "1.9.2" dependencies { implementation library.java.vendored_guava_32_1_2_jre implementation project(path: ":sdks:java:core", configuration: "shadow") - implementation project(":sdks:java:managed") + implementation project(path: ":model:pipeline", configuration: "shadow") implementation library.java.slf4j_api implementation library.java.joda_time implementation "org.apache.parquet:parquet-column:$parquet_version" @@ -55,6 +55,7 @@ dependencies { implementation "org.apache.iceberg:iceberg-orc:$iceberg_version" implementation library.java.hadoop_common + testImplementation project(":sdks:java:managed") testImplementation library.java.hadoop_client testImplementation library.java.bigdataoss_gcsio testImplementation library.java.bigdataoss_gcs_connector diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java index fa4ff9714c7f..1d4b36585237 100644 --- 
a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java @@ -24,7 +24,6 @@ import java.util.List; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.io.Read; -import org.apache.beam.sdk.managed.Managed; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PBegin; @@ -45,8 +44,8 @@ * A connector that reads and writes to Apache Iceberg * tables. * - *

{@link IcebergIO} is offered as a {@link Managed} transform. This class is subject to change - * and should not be used directly. Instead, use it via {@link Managed#ICEBERG} like so: + *

{@link IcebergIO} is offered as a Managed transform. This class is subject to change and + * should not be used directly. Instead, use it like so: * *

{@code
  * Map config = Map.of(
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java
index df7bda4560dd..d44149fda08e 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java
@@ -17,10 +17,12 @@
  */
 package org.apache.beam.sdk.io.iceberg;
 
+import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn;
+
 import com.google.auto.service.AutoService;
 import java.util.Collections;
 import java.util.List;
-import org.apache.beam.sdk.managed.ManagedTransformConstants;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
 import org.apache.beam.sdk.schemas.NoSuchSchemaException;
 import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
@@ -53,7 +55,7 @@ public List<String> outputCollectionNames() {
 
   @Override
   public String identifier() {
-    return ManagedTransformConstants.ICEBERG_READ;
+    return getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_READ);
   }
 
   static class IcebergReadSchemaTransform extends SchemaTransform {
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java
index ea46e8560815..6aa830e7fbc6 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java
@@ -18,13 +18,14 @@
 package org.apache.beam.sdk.io.iceberg;
 
 import static org.apache.beam.sdk.io.iceberg.IcebergWriteSchemaTransformProvider.Configuration;
+import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn;
 
 import com.google.auto.service.AutoService;
 import com.google.auto.value.AutoValue;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
-import org.apache.beam.sdk.managed.ManagedTransformConstants;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
 import org.apache.beam.sdk.schemas.AutoValueSchema;
 import org.apache.beam.sdk.schemas.NoSuchSchemaException;
 import org.apache.beam.sdk.schemas.Schema;
@@ -151,7 +152,7 @@ public List<String> outputCollectionNames() {
 
   @Override
   public String identifier() {
-    return ManagedTransformConstants.ICEBERG_WRITE;
+    return getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_WRITE);
   }
 
   static class IcebergWriteSchemaTransform extends SchemaTransform {
diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle
index 0ba6fa642a02..ec4654bd88df 100644
--- a/sdks/java/io/kafka/build.gradle
+++ b/sdks/java/io/kafka/build.gradle
@@ -54,6 +54,7 @@ dependencies {
   provided library.java.jackson_dataformat_csv
   permitUnusedDeclared library.java.jackson_dataformat_csv
   implementation project(path: ":sdks:java:core", configuration: "shadow")
+  implementation project(path: ":model:pipeline", configuration: "shadow")
   implementation project(":sdks:java:extensions:avro")
   implementation project(":sdks:java:extensions:protobuf")
   implementation project(":sdks:java:expansion-service")
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
index e87669ab2b0a..a3fd1d8c3fd7 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
@@ -18,6 +18,7 @@
 package org.apache.beam.sdk.io.kafka;
 
 import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
+import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn;
 
 import com.google.auto.service.AutoService;
 import java.io.FileOutputStream;
@@ -34,6 +35,7 @@
 import java.util.Map;
 import java.util.stream.Collectors;
 import org.apache.avro.generic.GenericRecord;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
 import org.apache.beam.sdk.extensions.avro.coders.AvroCoder;
 import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
 import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
@@ -103,7 +105,7 @@ public Row apply(byte[] input) {
 
   @Override
   public String identifier() {
-    return "beam:schematransform:org.apache.beam:kafka_read:v1";
+    return getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ);
   }
 
   @Override
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
index 09b338492b47..d6f46b11cb7d 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
@@ -17,6 +17,8 @@
  */
 package org.apache.beam.sdk.io.kafka;
 
+import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn;
+
 import com.google.auto.service.AutoService;
 import com.google.auto.value.AutoValue;
 import java.io.Serializable;
@@ -26,6 +28,7 @@
 import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
 import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
 import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
 import org.apache.beam.sdk.metrics.Counter;
@@ -249,7 +252,7 @@ public byte[] apply(Row input) {
 
   @Override
   public @UnknownKeyFor @NonNull @Initialized String identifier() {
-    return "beam:schematransform:org.apache.beam:kafka_write:v1";
+    return getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE);
   }
 
   @Override
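Both Kafka providers now resolve their identifiers from the shared proto enum instead of hard-coded strings. For orientation, a minimal sketch of how these transforms are reached from the Python side once the `beam.managed` wrapper added later in this patch is available; the topic, broker address, and format values are placeholders and a reachable Kafka broker is assumed:

    # Sketch only: topic, bootstrap server, and format are placeholders; a
    # Kafka broker must actually be reachable for this pipeline to run.
    import apache_beam as beam

    with beam.Pipeline() as p:
        _ = (
            p
            | beam.Create([beam.Row(name="a"), beam.Row(name="b")])
            | beam.managed.Write(
                beam.managed.KAFKA,
                config={
                    "topic": "my_topic",
                    "bootstrap_servers": "localhost:9092",
                    "format": "AVRO",
                }))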
diff --git a/sdks/java/managed/build.gradle b/sdks/java/managed/build.gradle
index add0d7f3cc0d..c6e868872246 100644
--- a/sdks/java/managed/build.gradle
+++ b/sdks/java/managed/build.gradle
@@ -28,6 +28,7 @@ ext.summary = """Library that provides managed IOs."""
 
 dependencies {
     implementation project(path: ":sdks:java:core", configuration: "shadow")
+    implementation project(path: ":model:pipeline", configuration: "shadow")
     implementation library.java.vendored_guava_32_1_2_jre
     implementation library.java.slf4j_api
 
diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java
index 911e25cdda14..8477726686ee 100644
--- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java
+++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java
@@ -17,11 +17,14 @@
  */
 package org.apache.beam.sdk.managed;
 
+import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn;
+
 import com.google.auto.value.AutoValue;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import javax.annotation.Nullable;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
 import org.apache.beam.sdk.coders.RowCoder;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
@@ -87,13 +90,13 @@ public class Managed {
   // Supported SchemaTransforms
   public static final Map<String, String> READ_TRANSFORMS =
       ImmutableMap.<String, String>builder()
-          .put(ICEBERG, ManagedTransformConstants.ICEBERG_READ)
-          .put(KAFKA, ManagedTransformConstants.KAFKA_READ)
+          .put(ICEBERG, getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_READ))
+          .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ))
           .build();
   public static final Map<String, String> WRITE_TRANSFORMS =
       ImmutableMap.<String, String>builder()
-          .put(ICEBERG, ManagedTransformConstants.ICEBERG_WRITE)
-          .put(KAFKA, ManagedTransformConstants.KAFKA_WRITE)
+          .put(ICEBERG, getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_WRITE))
+          .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE))
           .build();
 
   /**
diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java
index 51d0b67b4b89..4cf752747be5 100644
--- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java
+++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java
@@ -17,7 +17,10 @@
  */
 package org.apache.beam.sdk.managed;
 
+import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn;
+
 import java.util.Map;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 
 /**
@@ -41,12 +44,6 @@ public class ManagedTransformConstants {
   // Standard input PCollection tag
   public static final String INPUT = "input";
 
-  public static final String ICEBERG_READ = "beam:schematransform:org.apache.beam:iceberg_read:v1";
-  public static final String ICEBERG_WRITE =
-      "beam:schematransform:org.apache.beam:iceberg_write:v1";
-  public static final String KAFKA_READ = "beam:schematransform:org.apache.beam:kafka_read:v1";
-  public static final String KAFKA_WRITE = "beam:schematransform:org.apache.beam:kafka_write:v1";
-
   private static final Map<String, String> KAFKA_READ_MAPPINGS =
       ImmutableMap.<String, String>builder().put("data_format", "format").build();
 
@@ -55,7 +52,7 @@ public class ManagedTransformConstants {
 
   public static final Map<String, Map<String, String>> MAPPINGS =
       ImmutableMap.<String, Map<String, String>>builder()
-          .put(KAFKA_READ, KAFKA_READ_MAPPINGS)
-          .put(KAFKA_WRITE, KAFKA_WRITE_MAPPINGS)
+          .put(getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ), KAFKA_READ_MAPPINGS)
+          .put(getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE), KAFKA_WRITE_MAPPINGS)
           .build();
 }
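The URNs used as keys above are no longer free-standing string constants; they are derived via getUrn from the ManagedTransforms enum added to external_transforms.proto earlier in this patch, and the next hunk exposes the same enum to Python through `common_urns`. A small sketch of how the Python side resolves those URN strings (the printed values are the ones declared in the proto):

    # Prints the managed-transform URNs that both the Java constants above and
    # the Python beam.managed wrapper in this patch resolve to.
    from apache_beam.portability import common_urns

    urns = common_urns.ManagedTransforms.Urns
    for member in (urns.ICEBERG_READ, urns.ICEBERG_WRITE,
                   urns.KAFKA_READ, urns.KAFKA_WRITE):
        print(member.urn)  # e.g. beam:schematransform:org.apache.beam:iceberg_read:v1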
diff --git a/sdks/python/apache_beam/portability/common_urns.py b/sdks/python/apache_beam/portability/common_urns.py
index 4effc91c3d40..74d9a39bb052 100644
--- a/sdks/python/apache_beam/portability/common_urns.py
+++ b/sdks/python/apache_beam/portability/common_urns.py
@@ -38,6 +38,7 @@
 StandardSideInputTypes = beam_runner_api_pb2_urns.StandardSideInputTypes
 StandardUserStateTypes = beam_runner_api_pb2_urns.StandardUserStateTypes
 ExpansionMethods = external_transforms_pb2_urns.ExpansionMethods
+ManagedTransforms = external_transforms_pb2_urns.ManagedTransforms
 MonitoringInfo = metrics_pb2_urns.MonitoringInfo
 MonitoringInfoSpecs = metrics_pb2_urns.MonitoringInfoSpecs
 MonitoringInfoTypeUrns = metrics_pb2_urns.MonitoringInfoTypeUrns
diff --git a/sdks/python/apache_beam/transforms/__init__.py b/sdks/python/apache_beam/transforms/__init__.py
index 4e66a290842c..b8b6839019e8 100644
--- a/sdks/python/apache_beam/transforms/__init__.py
+++ b/sdks/python/apache_beam/transforms/__init__.py
@@ -22,6 +22,7 @@
 from apache_beam.transforms import combiners
 from apache_beam.transforms.core import *
 from apache_beam.transforms.external import *
+from apache_beam.transforms.managed import *
 from apache_beam.transforms.ptransform import *
 from apache_beam.transforms.stats import *
 from apache_beam.transforms.timeutil import TimeDomain
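With this import in place the new module is loaded together with the rest of `apache_beam.transforms`, so the transforms can be referenced either through the `beam.managed` attribute used in the examples below or by importing the names directly; a quick sketch:

    # Both spellings refer to the same classes once apache_beam is imported.
    import apache_beam as beam
    from apache_beam.transforms.managed import Read, Write

    assert Read is beam.managed.Read
    assert Write is beam.managed.Write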
diff --git a/sdks/python/apache_beam/transforms/managed.py b/sdks/python/apache_beam/transforms/managed.py
new file mode 100644
index 000000000000..22ee15b1de1c
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/managed.py
@@ -0,0 +1,182 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Managed Transforms.
+
+This module builds and instantiates turnkey transforms that can be managed by
+the underlying runner. This means the runner can upgrade the transform to a
+more optimal/updated version without requiring the user to do anything. It may
+also replace the transform with something entirely different if it chooses to.
+By default, however, the specified transform will remain unchanged.
+
+Using Managed Transforms
+========================
+Managed turnkey transforms have a defined configuration and can be built using
+an inline :class:`dict` like so::
+
+  results = p | beam.managed.Read(
+                    beam.managed.ICEBERG,
+                    config={"table": "foo",
+                            "catalog_name": "bar",
+                            "catalog_properties": {
+                                "warehouse": "path/to/warehouse",
+                                "catalog-impl": "org.apache.my.CatalogImpl"}})
+
+A YAML configuration file can also be used to build a Managed transform. Say we
+have the following `config.yaml` file::
+
+  topic: "foo"
+  bootstrap_servers: "localhost:1234"
+  format: "AVRO"
+
+Simply provide the location to the file like so::
+
+  input_rows = p | beam.Create(...)
+  input_rows | beam.managed.Write(
+                    beam.managed.KAFKA,
+                    config_url="path/to/config.yaml")
+
+Available transforms
+====================
+Available transforms are:
+
+- **Kafka Read and Write**
+- **Iceberg Read and Write**
+
+**Note:** inputs and outputs need to be PCollection(s) of Beam
+:py:class:`apache_beam.pvalue.Row` elements.
+
+**Note:** Today, all managed transforms are essentially cross-language
+transforms, and Java's ManagedSchemaTransform is used under the hood.
+"""
+
+from typing import Any
+from typing import Dict
+from typing import Optional
+
+import yaml
+
+from apache_beam.portability.common_urns import ManagedTransforms
+from apache_beam.transforms.external import BeamJarExpansionService
+from apache_beam.transforms.external import SchemaAwareExternalTransform
+from apache_beam.transforms.ptransform import PTransform
+
+ICEBERG = "iceberg"
+KAFKA = "kafka"
+_MANAGED_IDENTIFIER = "beam:transform:managed:v1"
+_EXPANSION_SERVICE_JAR_TARGETS = {
+    "sdks:java:io:expansion-service:shadowJar": [KAFKA, ICEBERG],
+}
+
+__all__ = ["ICEBERG", "KAFKA", "Read", "Write"]
+
+
+class Read(PTransform):
+  """Read using Managed Transforms"""
+  _READ_TRANSFORMS = {
+      ICEBERG: ManagedTransforms.Urns.ICEBERG_READ.urn,
+      KAFKA: ManagedTransforms.Urns.KAFKA_READ.urn,
+  }
+
+  def __init__(
+      self,
+      source: str,
+      config: Optional[Dict[str, Any]] = None,
+      config_url: Optional[str] = None,
+      expansion_service=None):
+    super().__init__()
+    self._source = source
+    identifier = self._READ_TRANSFORMS.get(source.lower())
+    if not identifier:
+      raise ValueError(
+          f"An unsupported source was specified: '{source}'. Please specify "
+          f"one of the following sources: {list(self._READ_TRANSFORMS.keys())}")
+
+    self._expansion_service = _resolve_expansion_service(
+        source, identifier, expansion_service)
+    self._underlying_identifier = identifier
+    self._yaml_config = yaml.dump(config)
+    self._config_url = config_url
+
+  def expand(self, input):
+    return input | SchemaAwareExternalTransform(
+        identifier=_MANAGED_IDENTIFIER,
+        expansion_service=self._expansion_service,
+        rearrange_based_on_discovery=True,
+        transform_identifier=self._underlying_identifier,
+        config=self._yaml_config,
+        config_url=self._config_url)
+
+  def default_label(self) -> str:
+    return "Managed Read(%s)" % self._source.upper()
+
+
+class Write(PTransform):
+  """Write using Managed Transforms"""
+  _WRITE_TRANSFORMS = {
+      ICEBERG: ManagedTransforms.Urns.ICEBERG_WRITE.urn,
+      KAFKA: ManagedTransforms.Urns.KAFKA_WRITE.urn,
+  }
+
+  def __init__(
+      self,
+      sink: str,
+      config: Optional[Dict[str, Any]] = None,
+      config_url: Optional[str] = None,
+      expansion_service=None):
+    super().__init__()
+    self._sink = sink
+    identifier = self._WRITE_TRANSFORMS.get(sink.lower())
+    if not identifier:
+      raise ValueError(
+          f"An unsupported sink was specified: '{sink}'. Please specify "
+          f"one of the following sinks: {list(self._WRITE_TRANSFORMS.keys())}")
+
+    self._expansion_service = _resolve_expansion_service(
+        sink, identifier, expansion_service)
+    self._underlying_identifier = identifier
+    self._yaml_config = yaml.dump(config)
+    self._config_url = config_url
+
+  def expand(self, input):
+    return input | SchemaAwareExternalTransform(
+        identifier=_MANAGED_IDENTIFIER,
+        expansion_service=self._expansion_service,
+        rearrange_based_on_discovery=True,
+        transform_identifier=self._underlying_identifier,
+        config=self._yaml_config,
+        config_url=self._config_url)
+
+  def default_label(self) -> str:
+    return "Managed Write(%s)" % self._sink.upper()
+
+
+def _resolve_expansion_service(
+    transform_name: str, identifier: str, expansion_service):
+  if expansion_service:
+    return expansion_service
+
+  default_target = None
+  for gradle_target, transforms in _EXPANSION_SERVICE_JAR_TARGETS.items():
+    if transform_name.lower() in transforms:
+      default_target = gradle_target
+      break
+  if not default_target:
+    raise ValueError(
+        "No expansion service was specified and could not find a "
+        f"default expansion service for {transform_name}: '{identifier}'.")
+  return BeamJarExpansionService(default_target)
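Read and Write also accept an explicit `expansion_service` argument, in which case `_resolve_expansion_service` simply returns it and the bundled expansion-service jar is not used. A minimal sketch, assuming an expansion service is already running on the given address and that the table and warehouse below exist (all three values are placeholders):

    # Sketch only: the service address, table name, and warehouse path are
    # placeholders; without expansion_service, the jar listed in
    # _EXPANSION_SERVICE_JAR_TARGETS would be started automatically.
    import apache_beam as beam

    with beam.Pipeline() as p:
        rows = p | beam.managed.Read(
            beam.managed.ICEBERG,
            config={
                "table": "db.users",
                "catalog_name": "default",
                "catalog_properties": {
                    "type": "hadoop",
                    "warehouse": "file:///tmp/warehouse",
                },
            },
            expansion_service="localhost:8097")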
diff --git a/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py b/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py
new file mode 100644
index 000000000000..2d7262bac031
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py
@@ -0,0 +1,70 @@
+import os
+import secrets
+import shutil
+import tempfile
+import time
+import unittest
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+
+@pytest.mark.uses_io_java_expansion_service
+@unittest.skipUnless(
+    os.environ.get('EXPANSION_JARS'),
+    "EXPANSION_JARS environment var is not provided, "
+    "indicating that jars have not been built")
+class ManagedIcebergIT(unittest.TestCase):
+  def setUp(self):
+    self._tempdir = tempfile.mkdtemp()
+    if not os.path.exists(self._tempdir):
+      os.mkdir(self._tempdir)
+    test_warehouse_name = 'test_warehouse_%d_%s' % (
+        int(time.time()), secrets.token_hex(3))
+    self.warehouse_path = os.path.join(self._tempdir, test_warehouse_name)
+    os.mkdir(self.warehouse_path)
+
+  def tearDown(self):
+    shutil.rmtree(self._tempdir, ignore_errors=False)
+
+  def _create_row(self, num: int):
+    return beam.Row(
+        int_=num,
+        str_=str(num),
+        bytes_=bytes(num),
+        bool_=(num % 2 == 0),
+        float_=(num + float(num) / 100))
+
+  def test_write_read_pipeline(self):
+    iceberg_config = {
+        "table": "test.write_read",
+        "catalog_name": "default",
+        "catalog_properties": {
+            "type": "hadoop",
+            "warehouse": f"file://{self.warehouse_path}",
+        }
+    }
+
+    rows = [self._create_row(i) for i in range(100)]
+    expected_dicts = [row.as_dict() for row in rows]
+
+    with beam.Pipeline() as write_pipeline:
+      _ = (
+          write_pipeline
+          | beam.Create(rows)
+          | beam.managed.Write(beam.managed.ICEBERG, config=iceberg_config))
+
+    with beam.Pipeline() as read_pipeline:
+      output_dicts = (
+          read_pipeline
+          | beam.managed.Read(beam.managed.ICEBERG, config=iceberg_config)
+          | beam.Map(lambda row: row._asdict()))
+
+      assert_that(output_dicts, equal_to(expected_dicts))
+
+
+if __name__ == '__main__':
+  unittest.main()
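Because the test uses a filesystem-backed `hadoop` catalog, it needs no external services; the same configuration can be reused outside the test harness, for example to re-read a table written by an earlier run. A standalone sketch (warehouse path and table name are placeholders that must point at existing data):

    # Standalone sketch of the read half of the test above.
    import apache_beam as beam

    config = {
        "table": "test.write_read",
        "catalog_name": "default",
        "catalog_properties": {
            "type": "hadoop",
            "warehouse": "file:///tmp/iceberg_warehouse",  # placeholder path
        },
    }

    with beam.Pipeline() as p:
        _ = (
            p
            | beam.managed.Read(beam.managed.ICEBERG, config=config)
            | beam.Map(lambda row: row._asdict())
            | beam.Map(print))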
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 15671eeb145b..b4175ad98e92 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -391,6 +391,7 @@ def get_portability_package_data():
           'sortedcontainers>=2.4.0',
           'typing-extensions>=3.7.0',
           'zstandard>=0.18.0,<1',
+          'pyyaml>=3.12,<7.0.0',
           # Dynamic dependencies must be specified in a separate list, otherwise
           # Dependabot won't be able to parse the main list. Any dynamic
           # dependencies will not receive updates from Dependabot.
@@ -415,7 +416,6 @@ def get_portability_package_data():
               'pandas<2.2.0',
               'parameterized>=0.7.1,<0.10.0',
               'pyhamcrest>=1.9,!=1.10.0,<3.0.0',
-              'pyyaml>=3.12,<7.0.0',
               'requests_mock>=1.7,<2.0',
               'tenacity>=8.0.0,<9',
               'pytest>=7.1.2,<8.0',
@@ -523,7 +523,6 @@ def get_portability_package_data():
           'yaml': [
               'docstring-parser>=0.15,<1.0',
               'jinja2>=3.0,<3.2',
-              'pyyaml>=3.12,<7.0.0',
               'virtualenv-clone>=0.5,<1.0',
               # https://github.com/PiotrDabkowski/Js2Py/issues/317
               'js2py>=0.74,<1; python_version<"3.12"',
diff --git a/sdks/standard_expansion_services.yaml b/sdks/standard_expansion_services.yaml
index ad965c8a1ee3..623e33fe2e7c 100644
--- a/sdks/standard_expansion_services.yaml
+++ b/sdks/standard_expansion_services.yaml
@@ -48,7 +48,7 @@
     # Handwritten Kafka wrappers already exist in apache_beam/io/kafka.py
     - 'beam:schematransform:org.apache.beam:kafka_write:v1'
     - 'beam:schematransform:org.apache.beam:kafka_read:v1'
-    # Not ready to generate
+    # Available through apache_beam.transforms.managed.[Read/Write]
     - 'beam:schematransform:org.apache.beam:iceberg_write:v1'
     - 'beam:schematransform:org.apache.beam:iceberg_read:v1'
 

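A minimal standalone sketch of the Managed Iceberg write exercised by the
integration test added in this patch. The table name and warehouse path below
are illustrative only, and, as in the test, running this requires the Java
expansion-service jars to be available locally:

    import apache_beam as beam

    # Illustrative config; keys mirror the integration test in this patch.
    iceberg_config = {
        "table": "db.example_table",          # assumed table name
        "catalog_name": "default",
        "catalog_properties": {
            "type": "hadoop",
            "warehouse": "file:///tmp/iceberg_warehouse",  # assumed local path
        },
    }

    with beam.Pipeline() as p:
        _ = (
            p
            | beam.Create([beam.Row(id=i, name=str(i)) for i in range(3)])
            | beam.managed.Write(beam.managed.ICEBERG, config=iceberg_config))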
From 80cba5644c338599144c3b227bc1db6b10a039f0 Mon Sep 17 00:00:00 2001
From: martin trieu 
Date: Tue, 15 Oct 2024 04:44:20 -0700
Subject: [PATCH 23/82] plumb backend worker token to work items (#32777)

---
 .../worker/streaming/ActiveWorkState.java     | 24 ++++++++++++----
 .../dataflow/worker/streaming/Work.java       | 28 +++++++++++++++----
 .../client/grpc/GrpcDirectGetWorkStream.java  |  6 +++-
 3 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
index 4607096dd66a..aec52cd7d9a6 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
@@ -240,18 +240,20 @@ synchronized Optional completeWorkAndGetNextWorkForKey(
     @Nullable Queue workQueue = activeWork.get(shardedKey);
     if (workQueue == null) {
       // Work may have been completed due to clearing of stuck commits.
-      LOG.warn("Unable to complete inactive work for key {} and token {}.", shardedKey, workId);
+      LOG.warn(
+          "Unable to complete inactive work for key={} and token={}.  Work queue for key does not exist.",
+          shardedKey,
+          workId);
       return Optional.empty();
     }
+
     removeCompletedWorkFromQueue(workQueue, shardedKey, workId);
     return getNextWork(workQueue, shardedKey);
   }
 
   private synchronized void removeCompletedWorkFromQueue(
       Queue workQueue, ShardedKey shardedKey, WorkId workId) {
-    // avoid Preconditions.checkState here to prevent eagerly evaluating the
-    // format string parameters for the error message.
-    ExecutableWork completedWork = workQueue.peek();
+    @Nullable ExecutableWork completedWork = workQueue.peek();
     if (completedWork == null) {
       // Work may have been completed due to clearing of stuck commits.
       LOG.warn("Active key {} without work, expected token {}", shardedKey, workId);
@@ -337,8 +339,18 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) {
     writer.println(
         "");
+    // Columns.
     writer.println(
-        "");
+        ""
+            + ""
+            + ""
+            + ""
+            + ""
+            + ""
+            + ""
+            + ""
+            + ""
+            + "");
     // Use StringBuilder because we are appending in loop.
     StringBuilder activeWorkStatus = new StringBuilder();
     int commitsPendingCount = 0;
@@ -366,6 +378,8 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) {
       activeWorkStatus.append(elapsedString(activeWork.getStateStartTime(), now));
       activeWorkStatus.append("</td><td>");
       activeWorkStatus.append(activeWork.getProcessingThreadName());
+      activeWorkStatus.append("</td><td>");
+      activeWorkStatus.append(activeWork.backendWorkerToken());
       activeWorkStatus.append("</td></tr>\n");
     }
 
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
index 03d1e1ae469a..6f97cbca9a80 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
@@ -56,7 +56,7 @@
 /**
  * Represents the state of an attempt to process a {@link WorkItem} by executing user code.
  *
- * @implNote Not thread safe, should not be executed or accessed by more than 1 thread at a time.
+ * @implNote Not thread safe, should not be modified by more than 1 thread at a time.
  */
 @NotThreadSafe
 @Internal
@@ -70,7 +70,7 @@ public final class Work implements RefreshableWork {
   private final Map totalDurationPerState;
   private final WorkId id;
   private final String latencyTrackingId;
-  private TimedState currentState;
+  private volatile TimedState currentState;
   private volatile boolean isFailed;
   private volatile String processingThreadName = "";
 
@@ -111,7 +111,18 @@ public static ProcessingContext createProcessingContext(
       GetDataClient getDataClient,
       Consumer workCommitter,
       HeartbeatSender heartbeatSender) {
-    return ProcessingContext.create(computationId, getDataClient, workCommitter, heartbeatSender);
+    return ProcessingContext.create(
+        computationId, getDataClient, workCommitter, heartbeatSender, /* backendWorkerToken= */ "");
+  }
+
+  public static ProcessingContext createProcessingContext(
+      String computationId,
+      GetDataClient getDataClient,
+      Consumer workCommitter,
+      HeartbeatSender heartbeatSender,
+      String backendWorkerToken) {
+    return ProcessingContext.create(
+        computationId, getDataClient, workCommitter, heartbeatSender, backendWorkerToken);
   }
 
   private static LatencyAttribution.Builder createLatencyAttributionWithActiveLatencyBreakdown(
@@ -168,6 +179,10 @@ public GlobalData fetchSideInput(GlobalDataRequest request) {
     return processingContext.getDataClient().getSideInputData(request);
   }
 
+  public String backendWorkerToken() {
+    return processingContext.backendWorkerToken();
+  }
+
   public Watermarks watermarks() {
     return watermarks;
   }
@@ -351,9 +366,10 @@ private static ProcessingContext create(
         String computationId,
         GetDataClient getDataClient,
         Consumer workCommitter,
-        HeartbeatSender heartbeatSender) {
+        HeartbeatSender heartbeatSender,
+        String backendWorkerToken) {
       return new AutoValue_Work_ProcessingContext(
-          computationId, getDataClient, heartbeatSender, workCommitter);
+          computationId, getDataClient, heartbeatSender, workCommitter, backendWorkerToken);
     }
 
     /** Computation that the {@link Work} belongs to. */
@@ -370,6 +386,8 @@ private static ProcessingContext create(
      */
     public abstract Consumer workCommitter();
 
+    public abstract String backendWorkerToken();
+
     private Optional fetchKeyedState(KeyedGetDataRequest request) {
       return Optional.ofNullable(getDataClient().getStateData(computationId(), request));
     }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
index 45d010d7cfac..19de998b1da8 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
@@ -254,7 +254,11 @@ private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
 
   private Work.ProcessingContext createProcessingContext(String computationId) {
     return Work.createProcessingContext(
-        computationId, getDataClient.get(), workCommitter.get()::commit, heartbeatSender.get());
+        computationId,
+        getDataClient.get(),
+        workCommitter.get()::commit,
+        heartbeatSender.get(),
+        backendWorkerToken());
   }
 
   @Override

From 45490aca9ace73e03d6c42e5d2c8668267daade4 Mon Sep 17 00:00:00 2001
From: Hai Joey Tran 
Date: Tue, 15 Oct 2024 10:27:16 -0400
Subject: [PATCH 24/82] Polish .with_exception_handling docstring (#32739)

* replace 'record' with 'input' and fix example

* more tweaking
---
 sdks/python/apache_beam/transforms/core.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 8a5bb00eeb98..be3cec6304f4 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1595,7 +1595,7 @@ def with_exception_handling(
       error_handler=None,
       on_failure_callback: typing.Optional[typing.Callable[
           [Exception, typing.Any], None]] = None):
-    """Automatically provides a dead letter output for skipping bad records.
+    """Automatically provides a dead letter output for saving bad inputs.
     This can allow a pipeline to continue successfully rather than fail or
     continuously throw errors on retry when bad elements are encountered.
 
@@ -1606,17 +1606,18 @@ def with_exception_handling(
 
     For example, one would write::
 
-        good, bad = Map(maybe_error_raising_function).with_exception_handling()
+        good, bad = inputs | Map(maybe_erroring_fn).with_exception_handling()
 
     and `good` will be a PCollection of mapped records and `bad` will contain
-    those that raised exceptions.
+    tuples of the form `(input, error_string)` for each input that raised an
+    exception.
 
 
     Args:
       main_tag: tag to be used for the main (good) output of the DoFn,
           useful to avoid possible conflicts if this DoFn already produces
           multiple outputs.  Optional, defaults to 'good'.
-      dead_letter_tag: tag to be used for the bad records, useful to avoid
+      dead_letter_tag: tag to be used for the bad inputs, useful to avoid
           possible conflicts if this DoFn already produces multiple outputs.
           Optional, defaults to 'bad'.
       exc_class: An exception class, or tuple of exception classes, to catch.
@@ -1635,9 +1636,9 @@ def with_exception_handling(
           than a new process per element, so the overhead should be minimal
           (and can be amortized if there's any per-process or per-bundle
           initialization that needs to be done). Optional, defaults to False.
-      threshold: An upper bound on the ratio of records that can be bad before
+      threshold: An upper bound on the ratio of inputs that can be bad before
           aborting the entire pipeline. Optional, defaults to 1.0 (meaning
-          up to 100% of records can be bad and the pipeline will still succeed).
+          up to 100% of inputs can be bad and the pipeline will still succeed).
       threshold_windowing: Event-time windowing to use for threshold. Optional,
           defaults to the windowing of the input.
       timeout: If the element has not finished processing in timeout seconds,

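A short, self-contained sketch of the dead-letter pattern this docstring
describes. The parse function and the input values are illustrative and not
part of the SDK or of this patch:

    import apache_beam as beam

    def parse(s):
        # Raises ValueError for non-numeric input.
        return int(s)

    with beam.Pipeline() as p:
        inputs = p | beam.Create(['1', '2', 'not-a-number'])
        good, bad = inputs | beam.Map(parse).with_exception_handling()
        good | 'PrintGood' >> beam.Map(print)
        # Each bad element pairs the failing input with its error, per the
        # docstring above.
        bad | 'PrintBad' >> beam.Map(print)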
From 8fa19f24b5e1efa9d3b788432e1556c661fab64d Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 15 Oct 2024 10:27:53 -0400
Subject: [PATCH 25/82] Bump github.com/aws/aws-sdk-go-v2 from 1.32.1 to 1.32.2
 in /sdks (#32711)

Bumps [github.com/aws/aws-sdk-go-v2](https://github.com/aws/aws-sdk-go-v2) from 1.32.1 to 1.32.2.
- [Release notes](https://github.com/aws/aws-sdk-go-v2/releases)
- [Commits](https://github.com/aws/aws-sdk-go-v2/compare/v1.32.1...v1.32.2)

---
updated-dependencies:
- dependency-name: github.com/aws/aws-sdk-go-v2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] 
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 sdks/go.mod | 2 +-
 sdks/go.sum | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sdks/go.mod b/sdks/go.mod
index 223ee2062b81..04362a5c43a3 100644
--- a/sdks/go.mod
+++ b/sdks/go.mod
@@ -30,7 +30,7 @@ require (
 	cloud.google.com/go/pubsub v1.43.0
 	cloud.google.com/go/spanner v1.67.0
 	cloud.google.com/go/storage v1.44.0
-	github.com/aws/aws-sdk-go-v2 v1.32.1
+	github.com/aws/aws-sdk-go-v2 v1.32.2
 	github.com/aws/aws-sdk-go-v2/config v1.27.42
 	github.com/aws/aws-sdk-go-v2/credentials v1.17.40
 	github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.28
diff --git a/sdks/go.sum b/sdks/go.sum
index 515c0c07a39a..ebeba7862d24 100644
--- a/sdks/go.sum
+++ b/sdks/go.sum
@@ -689,8 +689,8 @@ github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZve
 github.com/aws/aws-sdk-go v1.34.0 h1:brux2dRrlwCF5JhTL7MUT3WUwo9zfDHZZp3+g3Mvlmo=
 github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0=
 github.com/aws/aws-sdk-go-v2 v1.7.1/go.mod h1:L5LuPC1ZgDr2xQS7AmIec/Jlc7O/Y1u2KxJyNVab250=
-github.com/aws/aws-sdk-go-v2 v1.32.1 h1:8WuZ43ytA+TV6QEPT/R23mr7pWyI7bSSiEHdt9BS2Pw=
-github.com/aws/aws-sdk-go-v2 v1.32.1/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo=
+github.com/aws/aws-sdk-go-v2 v1.32.2 h1:AkNLZEyYMLnx/Q/mSKkcMqwNFXMAvFto9bNsHqcTduI=
+github.com/aws/aws-sdk-go-v2 v1.32.2/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo=
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 h1:pT3hpW0cOHRJx8Y0DfJUEQuqPild8jRGmSFmBgvydr0=
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6/go.mod h1:j/I2++U0xX+cr44QjHay4Cvxj6FUbnxrgmqN3H1jTZA=
 github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA=

From ade80d5cb3a311ce75261238a10af76469cb07b1 Mon Sep 17 00:00:00 2001
From: Shunping Huang 
Date: Tue, 15 Oct 2024 10:28:49 -0400
Subject: [PATCH 26/82] enable ordered list state (#32755)

---
 sdks/python/apache_beam/transforms/environments.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sdks/python/apache_beam/transforms/environments.py b/sdks/python/apache_beam/transforms/environments.py
index dbb227802925..77704e0522b2 100644
--- a/sdks/python/apache_beam/transforms/environments.py
+++ b/sdks/python/apache_beam/transforms/environments.py
@@ -895,6 +895,7 @@ def _python_sdk_capabilities_iter():
   yield common_urns.primitives.TO_STRING.urn
   yield common_urns.protocols.DATA_SAMPLING.urn
   yield common_urns.protocols.SDK_CONSUMING_RECEIVED_DATA.urn
+  yield common_urns.protocols.ORDERED_LIST_STATE.urn
 
 
 def python_sdk_dependencies(options, tmp_dir=None):

From 547cc3989f3dacbc7510b59e0fae91ad31b14d27 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 15 Oct 2024 09:21:11 -0700
Subject: [PATCH 27/82] Bump github.com/aws/aws-sdk-go-v2/feature/s3/manager in
 /sdks (#32768)

Bumps [github.com/aws/aws-sdk-go-v2/feature/s3/manager](https://github.com/aws/aws-sdk-go-v2) from 1.17.28 to 1.17.32.
- [Release notes](https://github.com/aws/aws-sdk-go-v2/releases)
- [Commits](https://github.com/aws/aws-sdk-go-v2/compare/credentials/v1.17.28...credentials/v1.17.32)

---
updated-dependencies:
- dependency-name: github.com/aws/aws-sdk-go-v2/feature/s3/manager
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] 
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 sdks/go.mod | 28 +++++++++++++--------------
 sdks/go.sum | 56 ++++++++++++++++++++++++++---------------------------
 2 files changed, 42 insertions(+), 42 deletions(-)

diff --git a/sdks/go.mod b/sdks/go.mod
index 04362a5c43a3..0b5ac98df404 100644
--- a/sdks/go.mod
+++ b/sdks/go.mod
@@ -31,10 +31,10 @@ require (
 	cloud.google.com/go/spanner v1.67.0
 	cloud.google.com/go/storage v1.44.0
 	github.com/aws/aws-sdk-go-v2 v1.32.2
-	github.com/aws/aws-sdk-go-v2/config v1.27.42
-	github.com/aws/aws-sdk-go-v2/credentials v1.17.40
-	github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.28
-	github.com/aws/aws-sdk-go-v2/service/s3 v1.65.0
+	github.com/aws/aws-sdk-go-v2/config v1.27.43
+	github.com/aws/aws-sdk-go-v2/credentials v1.17.41
+	github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.32
+	github.com/aws/aws-sdk-go-v2/service/s3 v1.65.3
 	github.com/aws/smithy-go v1.22.0
 	github.com/docker/go-connections v0.5.0
 	github.com/dustin/go-humanize v1.0.1
@@ -131,18 +131,18 @@ require (
 	github.com/apache/thrift v0.17.0 // indirect
 	github.com/aws/aws-sdk-go v1.34.0 // indirect
 	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect
-	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.16 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.20 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.20 // indirect
+	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.19 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21 // indirect
 	github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 // indirect
-	github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.0 // indirect
-	github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.1 // indirect
-	github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.0 // indirect
-	github.com/aws/aws-sdk-go-v2/service/sso v1.24.1 // indirect
-	github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.1 // indirect
-	github.com/aws/aws-sdk-go-v2/service/sts v1.32.1 // indirect
+	github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 // indirect
+	github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2 // indirect
+	github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 // indirect
+	github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 // indirect
+	github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 // indirect
+	github.com/aws/aws-sdk-go-v2/service/sts v1.32.2 // indirect
 	github.com/cenkalti/backoff/v4 v4.2.1 // indirect
 	github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
diff --git a/sdks/go.sum b/sdks/go.sum
index ebeba7862d24..db6d71b061b5 100644
--- a/sdks/go.sum
+++ b/sdks/go.sum
@@ -694,48 +694,48 @@ github.com/aws/aws-sdk-go-v2 v1.32.2/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 h1:pT3hpW0cOHRJx8Y0DfJUEQuqPild8jRGmSFmBgvydr0=
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6/go.mod h1:j/I2++U0xX+cr44QjHay4Cvxj6FUbnxrgmqN3H1jTZA=
 github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA=
-github.com/aws/aws-sdk-go-v2/config v1.27.42 h1:Zsy9coUPuOsCWkjTvHpl2/DB9bptXtv7WeNPxvFr87s=
-github.com/aws/aws-sdk-go-v2/config v1.27.42/go.mod h1:FGASs+PuJM2EY+8rt8qyQKLPbbX/S5oY+6WzJ/KE7ko=
+github.com/aws/aws-sdk-go-v2/config v1.27.43 h1:p33fDDihFC390dhhuv8nOmX419wjOSDQRb+USt20RrU=
+github.com/aws/aws-sdk-go-v2/config v1.27.43/go.mod h1:pYhbtvg1siOOg8h5an77rXle9tVG8T+BWLWAo7cOukc=
 github.com/aws/aws-sdk-go-v2/credentials v1.3.1/go.mod h1:r0n73xwsIVagq8RsxmZbGSRQFj9As3je72C2WzUIToc=
-github.com/aws/aws-sdk-go-v2/credentials v1.17.40 h1:RjnlA7t0p/IamxAM7FUJ5uS13Vszh4sjVGvsx91tGro=
-github.com/aws/aws-sdk-go-v2/credentials v1.17.40/go.mod h1:dgpdnSs1Bp/atS6vLlW83h9xZPP+uSPB/27dFSgC1BM=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.41 h1:7gXo+Axmp+R4Z+AK8YFQO0ZV3L0gizGINCOWxSLY9W8=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.41/go.mod h1:u4Eb8d3394YLubphT4jLEwN1rLNq2wFOlT6OuxFwPzU=
 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0/go.mod h1:2LAuqPx1I6jNfaGDucWfA2zqQCYCOMCDHiCOciALyNw=
-github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.16 h1:fwrer1pJeaiia0CcOfWVbZxvj9Adc7rsuaMTwPR0DIA=
-github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.16/go.mod h1:XyEwwp8XI4zMar7MTnJ0Sk7qY/9aN8Hp929XhuX5SF8=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 h1:TMH3f/SCAWdNtXXVPPu5D6wrr4G5hI1rAxbcocKfC7Q=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17/go.mod h1:1ZRXLdTpzdJb9fwTMXiLipENRxkGMTn1sfKexGllQCw=
 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2/go.mod h1:qaqQiHSrOUVOfKe6fhgQ6UzhxjwqVW8aHNegd6Ws4w4=
-github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.28 h1:yUPy1fwOKNZ9L52E9TCMomU+mKXNCgqi17dtYIdSolk=
-github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.28/go.mod h1:bJJP1cGMO0fPBgCjqHAWbc0WRbKrxrWU4hQfc/0ciAA=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.20 h1:OErdlGnt+hg3tTwGYAlKvFkKVUo/TXkoHcxDxuhYYU8=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.20/go.mod h1:HsPfuL5gs+407ByRXBMgpYoyrV1sgMrzd18yMXQHJpo=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.20 h1:822cE1CYSwY/EZnErlF46pyynuxvf1p+VydHRQW+XNs=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.20/go.mod h1:79/Tn7H7hYC5Gjz6fbnOV4OeBpkao7E8Tv95RO72pMM=
+github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.32 h1:C2hE+gJ40Cb4vzhFJ+tTzjvBpPloUq7XP6PD3A2Fk7g=
+github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.32/go.mod h1:0OmMtVNp+10JFBTfmA2AIeqBDm0YthDXmE+N7poaptk=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 h1:UAsR3xA31QGf79WzpG/ixT9FZvQlh5HY1NRqSHBNOCk=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21/go.mod h1:JNr43NFf5L9YaG3eKTm7HQzls9J+A9YYcGI5Quh1r2Y=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 h1:6jZVETqmYCadGFvrYEQfC5fAQmlo80CeL5psbno6r0s=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21/go.mod h1:1SR0GbLlnN3QUmYaflZNiH1ql+1qrSiB2vwcJ+4UM60=
 github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1/go.mod h1:Zy8smImhTdOETZqfyn01iNOe0CNggVbPjCajyaz6Gvg=
 github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ=
 github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc=
-github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.19 h1:FKdiFzTxlTRO71p0C7VrLbkkdW8qfMKF5+ej6bTmkT0=
-github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.19/go.mod h1:abO3pCj7WLQPTllnSeYImqFfkGrmJV0JovWo/gqT5N0=
+github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21 h1:7edmS3VOBDhK00b/MwGtGglCm7hhwNYnjJs/PgFdMQE=
+github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21/go.mod h1:Q9o5h4HoIWG8XfzxqiuK/CGUbepCJ8uTlaE3bAbxytQ=
 github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.1/go.mod h1:v33JQ57i2nekYTA70Mb+O18KeH4KqhdqxTJZNK1zdRE=
 github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 h1:TToQNkvGguu209puTojY/ozlqy2d/SFNcoLIqTFi42g=
 github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0/go.mod h1:0jp+ltwkf+SwG2fm/PKo8t4y8pJSgOCO4D8Lz3k0aHQ=
-github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.0 h1:FQNWhRuSq8QwW74GtU0MrveNhZbqvHsA4dkA9w8fTDQ=
-github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.0/go.mod h1:j/zZ3zmWfGCK91K73YsfHP53BSTLSjL/y6YN39XbBLM=
+github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 h1:4FMHqLfk0efmTqhXVRL5xYRqlEBNBiRI7N6w4jsEdd4=
+github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2/go.mod h1:LWoqeWlK9OZeJxsROW2RqrSPvQHKTpp69r/iDjwsSaw=
 github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1/go.mod h1:zceowr5Z1Nh2WVP8bf/3ikB41IZW59E4yIYbg+pC6mw=
-github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.1 h1:5vBMBTakOvtd8aNaicswcrr9qqCYUlasuzyoU6/0g8I=
-github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.1/go.mod h1:WSUbDa5qdg05Q558KXx2Scb+EDvOPXT9gfET0fyrJSk=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2 h1:s7NA1SOw8q/5c0wr8477yOPp0z+uBaXBnLE0XYb0POA=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2/go.mod h1:fnjjWyAW/Pj5HYOxl9LJqWtEwS7W2qgcRLWP+uWbss0=
 github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1/go.mod h1:6EQZIwNNvHpq/2/QSJnp4+ECvqIy55w95Ofs0ze+nGQ=
-github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.0 h1:1NKXS8XfhMM0bg5wVYa/eOH8AM2f6JijugbKEyQFTIg=
-github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.0/go.mod h1:ph931DUfVfgrhZR7py9olSvHCiRpvaGxNvlWBcXxFds=
+github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 h1:t7iUP9+4wdc5lt3E41huP+GvQZJD38WLsgVp4iOtAjg=
+github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2/go.mod h1:/niFCtmuQNxqx9v8WAPq5qh7EH25U4BF6tjoyq9bObM=
 github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1/go.mod h1:XLAGFrEjbvMCLvAtWLLP32yTv8GpBquCApZEycDLunI=
-github.com/aws/aws-sdk-go-v2/service/s3 v1.65.0 h1:2dSm7frMrw2tdJ0QvyccQNJyPGaP24dyDgZ6h1QJMGU=
-github.com/aws/aws-sdk-go-v2/service/s3 v1.65.0/go.mod h1:4XSVpw66upN8wND3JZA29eXl2NOZvfFVq7DIP6xvfuQ=
+github.com/aws/aws-sdk-go-v2/service/s3 v1.65.3 h1:xxHGZ+wUgZNACQmxtdvP5tgzfsxGS3vPpTP5Hy3iToE=
+github.com/aws/aws-sdk-go-v2/service/s3 v1.65.3/go.mod h1:cB6oAuus7YXRZhWCc1wIwPywwZ1XwweNp2TVAEGYeB8=
 github.com/aws/aws-sdk-go-v2/service/sso v1.3.1/go.mod h1:J3A3RGUvuCZjvSuZEcOpHDnzZP/sKbhDWV2T1EOzFIM=
-github.com/aws/aws-sdk-go-v2/service/sso v1.24.1 h1:aAIr0WhAgvKrxZtkBqne87Gjmd7/lJVTFkR2l2yuhL8=
-github.com/aws/aws-sdk-go-v2/service/sso v1.24.1/go.mod h1:8XhxGMWUfikJuginPQl5SGZ0LSJuNX3TCEQmFWZwHTM=
-github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.1 h1:J6kIsIkgFOaU6aKjigXJoue1XEHtKIIrpSh4vKdmRTs=
-github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.1/go.mod h1:2V2JLP7tXOmUbL3Hd1ojq+774t2KUAEQ35//shoNEL0=
+github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 h1:bSYXVyUzoTHoKalBmwaZxs97HU9DWWI3ehHSAMa7xOk=
+github.com/aws/aws-sdk-go-v2/service/sso v1.24.2/go.mod h1:skMqY7JElusiOUjMJMOv1jJsP7YUg7DrhgqZZWuzu1U=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 h1:AhmO1fHINP9vFYUE0LHzCWg/LfUWUF+zFPEcY9QXb7o=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2/go.mod h1:o8aQygT2+MVP0NaV6kbdE1YnnIM8RRVQzoeUH45GOdI=
 github.com/aws/aws-sdk-go-v2/service/sts v1.6.0/go.mod h1:q7o0j7d7HrJk/vr9uUt3BVRASvcU7gYZB9PUgPiByXg=
-github.com/aws/aws-sdk-go-v2/service/sts v1.32.1 h1:q76Ig4OaJzVJGNUSGO3wjSTBS94g+EhHIbpY9rPvkxs=
-github.com/aws/aws-sdk-go-v2/service/sts v1.32.1/go.mod h1:664dajZ7uS7JMUMUG0R5bWbtN97KECNCVdFDdQ6Ipu8=
+github.com/aws/aws-sdk-go-v2/service/sts v1.32.2 h1:CiS7i0+FUe+/YY1GvIBLLrR/XNGZ4CtM1Ll0XavNuVo=
+github.com/aws/aws-sdk-go-v2/service/sts v1.32.2/go.mod h1:HtaiBI8CjYoNVde8arShXb94UbQQi9L4EMr6D+xGBwo=
 github.com/aws/smithy-go v1.6.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E=
 github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM=
 github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=

From 89dd0887a4341fa4a1f59fdf413d35372a1fdffe Mon Sep 17 00:00:00 2001
From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com>
Date: Tue, 15 Oct 2024 20:42:00 +0300
Subject: [PATCH 28/82] Add license to fix RAT failure (#32785)

---
 .../transforms/managed_iceberg_it_test.py       | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py b/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py
index 2d7262bac031..0dfa2aa19c51 100644
--- a/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py
+++ b/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 import os
 import secrets
 import shutil

From 06ecee96e7541d108613b30ecc09abfcaf9929bd Mon Sep 17 00:00:00 2001
From: Danny McCormick 
Date: Tue, 15 Oct 2024 14:19:04 -0400
Subject: [PATCH 29/82] vLLM model handler efficiency improvements (#32687)

* vLLM model handler efficiency improvements

* fmt

* Remove bad exceptions

* lint

* lint
---
 .../trigger_files/beam_PostCommit_Python.json |   2 +-
 .../ml/inference/vllm_inference.py            | 103 ++++++++++++------
 2 files changed, 71 insertions(+), 34 deletions(-)

diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json
index 1eb60f6e4959..9e1d1e1b80dd 100644
--- a/.github/trigger_files/beam_PostCommit_Python.json
+++ b/.github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run.",
-  "modification": 3
+  "modification": 4
 }
 
diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py
index 28890083d93e..e1ba4f49b8fd 100644
--- a/sdks/python/apache_beam/ml/inference/vllm_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py
@@ -17,6 +17,7 @@
 
 # pytype: skip-file
 
+import asyncio
 import logging
 import os
 import subprocess
@@ -35,6 +36,7 @@
 from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
 from apache_beam.utils import subprocess_server
+from openai import AsyncOpenAI
 from openai import OpenAI
 
 try:
@@ -94,6 +96,15 @@ def getVLLMClient(port) -> OpenAI:
   )
 
 
+def getAsyncVLLMClient(port) -> AsyncOpenAI:
+  openai_api_key = "EMPTY"
+  openai_api_base = f"http://localhost:{port}/v1"
+  return AsyncOpenAI(
+      api_key=openai_api_key,
+      base_url=openai_api_base,
+  )
+
+
 class _VLLMModelServer():
   def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]):
     self._model_name = model_name
@@ -184,6 +195,34 @@ def __init__(
   def load_model(self) -> _VLLMModelServer:
     return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
 
+  async def _async_run_inference(
+      self,
+      batch: Sequence[str],
+      model: _VLLMModelServer,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    client = getAsyncVLLMClient(model.get_server_port())
+    inference_args = inference_args or {}
+    async_predictions = []
+    for prompt in batch:
+      try:
+        completion = client.completions.create(
+            model=self._model_name, prompt=prompt, **inference_args)
+        async_predictions.append(completion)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
+    predictions = []
+    for p in async_predictions:
+      try:
+        predictions.append(await p)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
+    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
   def run_inference(
       self,
       batch: Sequence[str],
@@ -200,22 +239,7 @@ def run_inference(
     Returns:
       An Iterable of type PredictionResult.
     """
-    client = getVLLMClient(model.get_server_port())
-    inference_args = inference_args or {}
-    predictions = []
-    # TODO(https://github.com/apache/beam/issues/32528): We should add support
-    # for taking in batches and doing a bunch of async calls. That will end up
-    # being more efficient when we can do in bundle batching.
-    for prompt in batch:
-      try:
-        completion = client.completions.create(
-            model=self._model_name, prompt=prompt, **inference_args)
-        predictions.append(completion)
-      except Exception as e:
-        model.check_connectivity()
-        raise e
-
-    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+    return asyncio.run(self._async_run_inference(batch, model, inference_args))
 
   def share_model_across_processes(self) -> bool:
     return True
@@ -272,28 +296,15 @@ def load_model(self) -> _VLLMModelServer:
 
     return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
 
-  def run_inference(
+  async def _async_run_inference(
       self,
       batch: Sequence[Sequence[OpenAIChatMessage]],
       model: _VLLMModelServer,
       inference_args: Optional[Dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
-    """Runs inferences on a batch of text strings.
-
-    Args:
-      batch: A sequence of examples as OpenAI messages.
-      model: A _VLLMModelServer for connecting to the spun up server.
-      inference_args: Any additional arguments for an inference.
-
-    Returns:
-      An Iterable of type PredictionResult.
-    """
-    client = getVLLMClient(model.get_server_port())
+    client = getAsyncVLLMClient(model.get_server_port())
     inference_args = inference_args or {}
-    predictions = []
-    # TODO(https://github.com/apache/beam/issues/32528): We should add support
-    # for taking in batches and doing a bunch of async calls. That will end up
-    # being more efficient when we can do in bundle batching.
+    async_predictions = []
     for messages in batch:
       formatted = []
       for message in messages:
@@ -301,12 +312,38 @@ def run_inference(
       try:
         completion = client.chat.completions.create(
             model=self._model_name, messages=formatted, **inference_args)
-        predictions.append(completion)
+        async_predictions.append(completion)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
+    predictions = []
+    for p in async_predictions:
+      try:
+        predictions.append(await p)
       except Exception as e:
         model.check_connectivity()
         raise e
 
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
+  def run_inference(
+      self,
+      batch: Sequence[Sequence[OpenAIChatMessage]],
+      model: _VLLMModelServer,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """Runs inferences on a batch of text strings.
+
+    Args:
+      batch: A sequence of examples as OpenAI messages.
+      model: A _VLLMModelServer for connecting to the spun up server.
+      inference_args: Any additional arguments for an inference.
+
+    Returns:
+      An Iterable of type PredictionResult.
+    """
+    return asyncio.run(self._async_run_inference(batch, model, inference_args))
+
   def share_model_across_processes(self) -> bool:
     return True

From 1b744a7fbc2725ed98a130027a042c3adc10b96b Mon Sep 17 00:00:00 2001
From: claudevdm <33973061+claudevdm@users.noreply.github.com>
Date: Tue, 15 Oct 2024 14:22:40 -0400
Subject: [PATCH 30/82] Created using Colab (#32789)

---
 bigquery_enrichment_transform.ipynb | 781 ++++++++++++++++++++++++++++
 1 file changed, 781 insertions(+)
 create mode 100644 bigquery_enrichment_transform.ipynb

diff --git a/bigquery_enrichment_transform.ipynb b/bigquery_enrichment_transform.ipynb
new file mode 100644
index 000000000000..331ecb9ba93d
--- /dev/null
+++ b/bigquery_enrichment_transform.ipynb
@@ -0,0 +1,781 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "view-in-github",
+        "colab_type": "text"
+      },
+      "source": [
+        "\"Open"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+        "\n",
+        "# Licensed to the Apache Software Foundation (ASF) under one\n",
+        "# or more contributor license agreements. See the NOTICE file\n",
+        "# distributed with this work for additional information\n",
+        "# regarding copyright ownership. The ASF licenses this file\n",
+        "# to you under the Apache License, Version 2.0 (the\n",
+        "# \"License\"); you may not use this file except in compliance\n",
+        "# with the License. You may obtain a copy of the License at\n",
+        "#\n",
+        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
+        "#\n",
+        "# Unless required by applicable law or agreed to in writing,\n",
+        "# software distributed under the License is distributed on an\n",
+        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+        "# KIND, either express or implied. See the License for the\n",
+        "# specific language governing permissions and limitations\n",
+        "# under the License"
+      ],
+      "metadata": {
+        "id": "55h6JBJeJGqg",
+        "cellView": "form"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Use Apache Beam and BigQuery to enrich data\n",
+        "\n",
+        "
KeyTokenQueuedActive ForStateState Active ForProcessing Thread
KeyTokenQueuedActive ForStateState Active ForProcessing ThreadBackend
"); activeWorkStatus.append(activeWork.getProcessingThreadName()); + activeWorkStatus.append(""); + activeWorkStatus.append(activeWork.backendWorkerToken()); activeWorkStatus.append("
\n", + " \n", + " \n", + "
\n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
\n" + ], + "metadata": { + "id": "YrOuxMeKJZxC" + } + }, + { + "cell_type": "markdown", + "source": [ + "This notebook shows how to enrich data by using the Apache Beam [enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) with [BigQuery](https://cloud.google.com/bigquery/docs/overview). The enrichment transform is an Apache Beam turnkey transform that lets you enrich data by using a key-value lookup. This transform has the following features:\n", + "\n", + "- The transform has a built-in Apache Beam handler that interacts with BigQuery data during enrichment.\n", + "- The enrichment transform uses client-side throttling to rate limit the requests. The default retry strategy uses exponential backoff. You can configure rate limiting to suit your use case.\n", + "\n", + "This notebook demonstrates the following telecommunications company use case:\n", + "\n", + "A telecom company wants to predict which customers are likely to cancel their subscriptions so that the company can proactively offer these customers incentives to stay. The example uses customer demographic data and usage data stored in BigQuery to enrich a stream of customer IDs. The enriched data is then used to predict the likelihood of customer churn.\n", + "\n", + "## Before you begin\n", + "Set up your environment and download dependencies.\n", + "\n", + "### Install Apache Beam\n", + "To use the enrichment transform with the built-in BigQuery handler, install the Apache Beam SDK version 2.57.0 or later." + ], + "metadata": { + "id": "pf2bL-PmJScZ" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install torch\n", + "!pip install apache_beam[interactive,gcp]==2.57.0 --quiet" + ], + "metadata": { + "id": "oVbWf73FJSzf" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Import the following modules:\n", + "- Pub/Sub for streaming data\n", + "- BigQuery for enrichment\n", + "- Apache Beam for running the streaming pipeline\n", + "- PyTorch to predict customer churn" + ], + "metadata": { + "id": "siSUsfR5tKX9" + } + }, + { + "cell_type": "code", + "source": [ + "import datetime\n", + "import json\n", + "import math\n", + "\n", + "from typing import Any\n", + "from typing import Dict\n", + "\n", + "import torch\n", + "from google.cloud import pubsub_v1\n", + "from google.cloud import bigquery\n", + "from google.api_core.exceptions import Conflict\n", + "\n", + "import apache_beam as beam\n", + "import apache_beam.runners.interactive.interactive_beam as ib\n", + "from apache_beam.ml.inference.base import KeyedModelHandler\n", + "from apache_beam.ml.inference.base import RunInference\n", + "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n", + "from apache_beam.options import pipeline_options\n", + "from apache_beam.runners.interactive.interactive_runner import InteractiveRunner\n", + "from apache_beam.transforms.enrichment import Enrichment\n", + "from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler\n", + "\n", + "import pandas as pd\n", + "\n", + "from sklearn.preprocessing import LabelEncoder" + ], + "metadata": { + "id": "p6bruDqFJkXE" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Authenticate with Google Cloud\n", + "This notebook reads data from Pub/Sub and BigQuery. 
To use your Google Cloud account, authenticate this notebook.\n", + "To prepare for this step, replace `` with your Google Cloud project ID." + ], + "metadata": { + "id": "t0QfhuUlJozO" + } + }, + { + "cell_type": "code", + "source": [ + "PROJECT_ID = \"\"\n" + ], + "metadata": { + "id": "RwoBZjD1JwnD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from google.colab import auth\n", + "auth.authenticate_user(project_id=PROJECT_ID)" + ], + "metadata": { + "id": "rVAyQxoeKflB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Set up the BigQuery tables\n", + "\n", + "Create sample BigQuery tables for this notebook.\n", + "\n", + "- Replace `` with the name of your BigQuery dataset. Only letters (uppercase or lowercase), numbers, and underscores are allowed.\n", + "- If the dataset does not exist, a new dataset with this ID is created." + ], + "metadata": { + "id": "1vDwknoHKoa-" + } + }, + { + "cell_type": "code", + "source": [ + "DATASET_ID = \"\"\n", + "\n", + "CUSTOMERS_TABLE_ID = f'{PROJECT_ID}.{DATASET_ID}.customers'\n", + "USAGE_TABLE_ID = f'{PROJECT_ID}.{DATASET_ID}.usage'" + ], + "metadata": { + "id": "UxeGFqSJu-G6" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Create customer and usage tables, and insert fake data." + ], + "metadata": { + "id": "Gw4RfZavyfpo" + } + }, + { + "cell_type": "code", + "source": [ + "client = bigquery.Client(project=PROJECT_ID)\n", + "\n", + "# Create dataset if it does not exist.\n", + "client.create_dataset(bigquery.Dataset(f\"{PROJECT_ID}.{DATASET_ID}\"), exists_ok=True)\n", + "print(f\"Created dataset {DATASET_ID}\")\n", + "\n", + "# Prepare the fake customer data.\n", + "customer_data = {\n", + " 'customer_id': [1, 2, 3, 4, 5],\n", + " 'age': [35, 28, 45, 62, 22],\n", + " 'plan': ['Gold', 'Silver', 'Bronze', 'Gold', 'Silver'],\n", + " 'contract_length': [12, 24, 6, 36, 12]\n", + "}\n", + "\n", + "customers_df = pd.DataFrame(customer_data)\n", + "\n", + "# Insert customer data.\n", + "job_config = bigquery.LoadJobConfig(\n", + " schema=[\n", + " bigquery.SchemaField(\"customer_id\", \"INTEGER\"),\n", + " bigquery.SchemaField(\"age\", \"INTEGER\"),\n", + " bigquery.SchemaField(\"plan\", \"STRING\"),\n", + " bigquery.SchemaField(\"contract_length\", \"INTEGER\"),\n", + " ],\n", + " write_disposition=\"WRITE_TRUNCATE\",\n", + ")\n", + "\n", + "job = client.load_table_from_dataframe(\n", + " customers_df, CUSTOMERS_TABLE_ID, job_config=job_config\n", + ")\n", + "job.result() # Wait for the job to complete.\n", + "print(f\"Customers table created and populated: {CUSTOMERS_TABLE_ID}\")\n", + "\n", + "# Prepare the fake usage data.\n", + "usage_data = {\n", + " 'customer_id': [1, 1, 2, 2, 3, 3, 4, 4, 5, 5],\n", + " 'date': pd.to_datetime(['2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01']),\n", + " 'calls_made': [50, 65, 20, 18, 100, 110, 30, 28, 60, 70],\n", + " 'data_usage_gb': [10, 12, 5, 4, 20, 22, 8, 7, 15, 18]\n", + "}\n", + "usage_df = pd.DataFrame(usage_data)\n", + "\n", + "# Insert usage data.\n", + "job_config = bigquery.LoadJobConfig(\n", + " schema=[\n", + " bigquery.SchemaField(\"customer_id\", \"INTEGER\"),\n", + " bigquery.SchemaField(\"date\", \"DATE\"),\n", + " bigquery.SchemaField(\"calls_made\", \"INTEGER\"),\n", + " bigquery.SchemaField(\"data_usage_gb\", \"FLOAT\"),\n", + " ],\n", + " 
write_disposition=\"WRITE_TRUNCATE\",\n", + ")\n", + "job = client.load_table_from_dataframe(\n", + " usage_df, USAGE_TABLE_ID, job_config=job_config\n", + ")\n", + "job.result() # Wait for the job to complete.\n", + "\n", + "print(f\"Usage table created and populated: {USAGE_TABLE_ID}\")" + ], + "metadata": { + "id": "-QRZC4v0KipK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Train the model" + ], + "metadata": { + "id": "PZCjCzxaLOJt" + } + }, + { + "cell_type": "markdown", + "source": [ + "Create sample data and train a simple model for churn prediction." + ], + "metadata": { + "id": "R4dIHclDLfIj" + } + }, + { + "cell_type": "code", + "source": [ + "# Create fake training data\n", + "data = {\n", + " 'customer_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],\n", + " 'age': [35, 28, 45, 62, 22, 38, 55, 25, 40, 30],\n", + " 'plan': ['Gold', 'Silver', 'Bronze', 'Gold', 'Silver', 'Bronze', 'Gold', 'Silver', 'Bronze', 'Silver'],\n", + " 'contract_length': [12, 24, 6, 36, 12, 18, 30, 12, 24, 18],\n", + " 'avg_monthly_calls': [57.5, 19, 100, 30, 60, 45, 25, 70, 50, 35],\n", + " 'avg_monthly_data_usage_gb': [11, 4.5, 20, 8, 15, 10, 7, 18, 12, 8],\n", + " 'churned': [0, 0, 1, 0, 1, 0, 0, 1, 0, 1] # Target variable\n", + "}\n", + "plan_encoder = LabelEncoder()\n", + "plan_encoder.fit(data['plan'])\n", + "df = pd.DataFrame(data)\n", + "df['plan'] = plan_encoder.transform(data['plan'])\n", + "\n" + ], + "metadata": { + "id": "YoMjdqJ1KxOM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Preprocess the data:\n", + "\n", + "1. Convert the lists to tensors.\n", + "2. Separate the features from the expected prediction." + ], + "metadata": { + "id": "EgIFJx76MF3v" + } + }, + { + "cell_type": "code", + "source": [ + "features = ['age', 'plan', 'contract_length', 'avg_monthly_calls', 'avg_monthly_data_usage_gb']\n", + "target = 'churned'\n", + "\n", + "X = torch.tensor(df[features].values, dtype=torch.float)\n", + "Y = torch.tensor(df[target], dtype=torch.float)" + ], + "metadata": { + "id": "P-8lKzdzLnGo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Define a model that has five input features and predicts a single value." + ], + "metadata": { + "id": "4mcNOez1MQZP" + } + }, + { + "cell_type": "code", + "source": [ + "def build_model(n_inputs, n_outputs):\n", + " \"\"\"build_model builds and returns a model that takes\n", + " `n_inputs` features and predicts `n_outputs` value\"\"\"\n", + " return torch.nn.Sequential(\n", + " torch.nn.Linear(n_inputs, 8),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(8, 16),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(16, n_outputs),\n", + " torch.nn.Sigmoid())" + ], + "metadata": { + "id": "YvdPNlzoMTtl" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Train the model." 
+ ], + "metadata": { + "id": "GaLBmcvrMOWy" + } + }, + { + "cell_type": "code", + "source": [ + "model = build_model(n_inputs=5, n_outputs=1)\n", + "\n", + "loss_fn = torch.nn.BCELoss()\n", + "optimizer = torch.optim.Adam(model.parameters())\n", + "\n", + "for epoch in range(1000):\n", + " print(f'Epoch {epoch}: ---')\n", + " optimizer.zero_grad()\n", + " for i in range(len(X)):\n", + " pred = model(X[i])\n", + " loss = loss_fn(pred, Y[i].unsqueeze(0))\n", + " loss.backward()\n", + " optimizer.step()" + ], + "metadata": { + "id": "0XqctMiPMaim" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Save the model to the `STATE_DICT_PATH` variable." + ], + "metadata": { + "id": "m7MD6RwGMdyU" + } + }, + { + "cell_type": "code", + "source": [ + "STATE_DICT_PATH = './model.pth'\n", + "torch.save(model.state_dict(), STATE_DICT_PATH)" + ], + "metadata": { + "id": "Q9WIjw53MgcR" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Publish messages to Pub/Sub\n", + "Create the Pub/Sub topic and subscription to use for data streaming." + ], + "metadata": { + "id": "CJVYA0N0MnZS" + } + }, + { + "cell_type": "code", + "source": [ + "# Replace with the name of your Pub/Sub topic.\n", + "TOPIC = \"\"\n", + "\n", + "# Replace with the subscription for your topic.\n", + "SUBSCRIPTION = \"\"" + ], + "metadata": { + "id": "0uwZz_ijyzL8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from google.api_core.exceptions import AlreadyExists\n", + "\n", + "publisher = pubsub_v1.PublisherClient()\n", + "topic_path = publisher.topic_path(PROJECT_ID, TOPIC)\n", + "try:\n", + " topic = publisher.create_topic(request={\"name\": topic_path})\n", + " print(f\"Created topic: {topic.name}\")\n", + "except AlreadyExists:\n", + " print(f\"Topic {topic_path} already exists.\")\n", + "\n", + "subscriber = pubsub_v1.SubscriberClient()\n", + "subscription_path = subscriber.subscription_path(PROJECT_ID, SUBSCRIPTION)\n", + "try:\n", + " subscription = subscriber.create_subscription(\n", + " request={\"name\": subscription_path, \"topic\": topic_path}\n", + " )\n", + " print(f\"Created subscription: {subscription.name}\")\n", + "except AlreadyExists:\n", + " print(f\"Subscription {subscription_path} already exists.\")" + ], + "metadata": { + "id": "hIgsCWIozdDu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "Use the Pub/Sub Python client to publish messages." 
+ ], + "metadata": { + "id": "VqUaFm_yywjU" + } + }, + { + "cell_type": "code", + "source": [ + "messages = [\n", + " {'customer_id': i}\n", + " for i in range(1, 6)\n", + "]\n", + "\n", + "for message in messages:\n", + " data = json.dumps(message).encode('utf-8')\n", + " publish_future = publisher.publish(topic_path, data)" + ], + "metadata": { + "id": "fOq1uNXvMku-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Use the BigQuery enrichment handler\n", + "\n", + "The [`BigQueryEnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigquery.html#apache_beam.transforms.enrichment_handlers.bigquery.BigQueryEnrichmentHandler) is a built-in handler included in the Apache Beam SDK versions 2.57.0 and later.\n", + "\n", + "Configure the `BigQueryEnrichmentHandler` handler with the following parameters.\n", + "\n", + "### Required parameters\n", + "\n", + "The following parameters are required.\n", + "\n", + "* `project` (str): The Google Cloud project ID for the BigQuery table\n", + "\n", + "You must also provide one of the following combinations:\n", + "* `table_name`, `row_restriction_template`, and `fields`\n", + "* `table_name`, `row_restriction_template`, and `condition_value_fn`\n", + "* `query_fn`\n", + "\n", + "### Optional parameters\n", + "\n", + "The following parameters are optional.\n", + "\n", + "* `table_name` (str): The fully qualified BigQuery table name in the format `project.dataset.table`\n", + "* `row_restriction_template` (str): A template string for the `WHERE` clause in the BigQuery query with placeholders (`{}`) to dynamically filter rows based on input data\n", + "* `fields` (Optional[List[str]]): A list of field names present in the input `beam.Row`. These fields names are used to construct the `WHERE` clause if `condition_value_fn` is not provided.\n", + "* `column_names` (Optional[List[str]]): The names of columns to select from the BigQuery table. If not provided, all columns (`*`) are selected.\n", + "* `condition_value_fn` (Optional[Callable[[beam.Row], List[Any]]]): A function that takes a `beam.Row` and returns a list of values to populate in the placeholder `{}` of the `WHERE` clause in the query\n", + "* `query_fn` (Optional[Callable[[beam.Row], str]]): A function that takes a `beam.Row` and returns a complete BigQuery SQL query string\n", + "* `min_batch_size` (int): The minimum number of rows to batch together when querying BigQuery. Defaults to `1` if `query_fn` is not specified.\n", + "* `max_batch_size` (int): The maximum number of rows to batch together. 
Defaults to `10,000` if `query_fn` is not specified.\n", + "\n", + "### Parameter requirements\n", + "\n", + "When you use parameters, consider the following requirements.\n", + "\n", + "* You can't define the `min_batch_size` and `max_batch_size` parameters if you provide the `query_fn` parameter.\n", + "* You must provide either the `fields` parameter or the `condition_value_fn` parameter for query construction if you don't provide the `query_fn` parameter.\n", + "* You must grant the appropriate permissions to access BigQuery.\n", + "\n", + "### Create handlers\n", + "\n", + "In this example, you create two handlers:\n", + "\n", + "* One for customer data that specifies `table_name` and `row_restriction_template`\n", + "* One for for usage data that uses a custom aggregation query by using the `query_fn` function\n", + "\n", + "These handlers are used in the Enrichment transforms in this pipeline to fetch and join data from BigQuery with the streaming data." + ], + "metadata": { + "id": "giXOGruKM8ZL" + } + }, + { + "cell_type": "code", + "source": [ + "user_data_handler = BigQueryEnrichmentHandler(\n", + " project=PROJECT_ID,\n", + " table_name=f\"`{CUSTOMERS_TABLE_ID}`\",\n", + " row_restriction_template='customer_id = {}',\n", + " fields=['customer_id']\n", + ")\n", + "\n", + "# Define the SQL query for usage data aggregation.\n", + "usage_data_query_template = f\"\"\"\n", + "WITH monthly_aggregates AS (\n", + " SELECT\n", + " customer_id,\n", + " DATE_TRUNC(date, MONTH) as month,\n", + " SUM(calls_made) as total_calls,\n", + " SUM(data_usage_gb) as total_data_usage_gb\n", + " FROM\n", + " `{USAGE_TABLE_ID}`\n", + " WHERE\n", + " customer_id = @customer_id\n", + " GROUP BY\n", + " customer_id, month\n", + ")\n", + "SELECT\n", + " customer_id,\n", + " AVG(total_calls) as avg_monthly_calls,\n", + " AVG(total_data_usage_gb) as avg_monthly_data_usage_gb\n", + "FROM\n", + " monthly_aggregates\n", + "GROUP BY\n", + " customer_id\n", + "\"\"\"\n", + "\n", + "def usage_data_query_fn(row: beam.Row) -> str:\n", + " return usage_data_query_template.replace('@customer_id', str(row.customer_id))\n", + "\n", + "usage_data_handler = BigQueryEnrichmentHandler(\n", + " project=PROJECT_ID,\n", + " query_fn=usage_data_query_fn\n", + ")" + ], + "metadata": { + "id": "C8XLmBDeMyrB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "In this example:\n", + "1. The `user_data_handler` handler uses the `table_name`, `row_restriction_template`, and `fields` parameter combination to fetch customer data.\n", + "2. The `usage_data_handler` handler uses the `query_fn` parameter to execute a more complex query that aggregates usage data." + ], + "metadata": { + "id": "3oPYypvmPiyg" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Use the `PytorchModelHandlerTensor` interface to run inference\n", + "\n", + "Define functions to convert enriched data to the tensor format for the model." 
+ ], + "metadata": { + "id": "ksON9uOBQbZm" + } + }, + { + "cell_type": "code", + "source": [ + "def convert_row_to_tensor(customer_data):\n", + " import pandas as pd\n", + " customer_df = pd.DataFrame([customer_data[1].as_dict()])\n", + " customer_df['plan'] = plan_encoder.transform(customer_df['plan'])\n", + " return (customer_data[0], torch.tensor(customer_df[features].values, dtype=torch.float))\n", + "\n", + "keyed_model_handler = KeyedModelHandler(PytorchModelHandlerTensor(\n", + " state_dict_path=STATE_DICT_PATH,\n", + " model_class=build_model,\n", + " model_params={'n_inputs':5, 'n_outputs':1}\n", + ")).with_preprocess_fn(convert_row_to_tensor)" + ], + "metadata": { + "id": "XgPontIVP0Cv" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Define a `DoFn` to format the output." + ], + "metadata": { + "id": "O9e7ddgGQxh2" + } + }, + { + "cell_type": "code", + "source": [ + "class PostProcessor(beam.DoFn):\n", + " def process(self, element, *args, **kwargs):\n", + " print('Customer %d churn risk: %s' % (element[0], \"High\" if element[1].inference[0].item() > 0.5 else \"Low\"))" + ], + "metadata": { + "id": "NMj0V5VyQukk" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Run the pipeline\n", + "\n", + "Configure the pipeline to run in streaming mode." + ], + "metadata": { + "id": "-N3a1s2FQ66z" + } + }, + { + "cell_type": "code", + "source": [ + "options = pipeline_options.PipelineOptions()\n", + "options.view_as(pipeline_options.StandardOptions).streaming = True # Streaming mode is set True" + ], + "metadata": { + "id": "rgJeV-jWQ4wo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Pub/Sub sends the data in bytes. Convert the data to `beam.Row` objects by using a `DoFn`." + ], + "metadata": { + "id": "NRljYVR5RCMi" + } + }, + { + "cell_type": "code", + "source": [ + "class DecodeBytes(beam.DoFn):\n", + " \"\"\"\n", + " The DecodeBytes `DoFn` converts the data read from Pub/Sub to `beam.Row`.\n", + " First, decode the encoded string. Convert the output to\n", + " a `dict` with `json.loads()`, which is used to create a `beam.Row`.\n", + " \"\"\"\n", + " def process(self, element, *args, **kwargs):\n", + " element_dict = json.loads(element.decode('utf-8'))\n", + " yield beam.Row(**element_dict)" + ], + "metadata": { + "id": "Bb-e3yjtQ2iU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Use the following code to run the pipeline.\n", + "\n", + "**Note:** Because this pipeline is a streaming pipeline, you need to manually stop the cell. If you don't stop the cell, the pipeline continues to run." 
+ ], + "metadata": { + "id": "Q1HV8wH-RIbj" + } + }, + { + "cell_type": "code", + "source": [ + "with beam.Pipeline(options=options) as p:\n", + " _ = (p\n", + " | \"Read from Pub/Sub\" >> beam.io.ReadFromPubSub(subscription=f\"projects/{PROJECT_ID}/subscriptions/{SUBSCRIPTION}\")\n", + " | \"ConvertToRow\" >> beam.ParDo(DecodeBytes())\n", + " | \"Enrich with customer data\" >> Enrichment(user_data_handler)\n", + " | \"Enrich with usage data\" >> Enrichment(usage_data_handler)\n", + " | \"Key data\" >> beam.Map(lambda x: (x.customer_id, x))\n", + " | \"RunInference\" >> RunInference(keyed_model_handler)\n", + " | \"Format Output\" >> beam.ParDo(PostProcessor())\n", + " )" + ], + "metadata": { + "id": "y6HBH8yoRFp2" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file From 06692ca1ed32f3815379d561fc61e120f01eaa4e Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Tue, 15 Oct 2024 15:15:02 -0400 Subject: [PATCH 31/82] Revert disabled Gradle cache in #32751 (#32771) --- sdks/java/expansion-service/build.gradle | 4 ---- sdks/java/extensions/sql/expansion-service/build.gradle | 4 ---- sdks/java/io/expansion-service/build.gradle | 1 - sdks/python/apache_beam/yaml/yaml_provider.py | 2 +- 4 files changed, 1 insertion(+), 10 deletions(-) diff --git a/sdks/java/expansion-service/build.gradle b/sdks/java/expansion-service/build.gradle index a25583870acf..4dd8c8968ed9 100644 --- a/sdks/java/expansion-service/build.gradle +++ b/sdks/java/expansion-service/build.gradle @@ -57,7 +57,3 @@ task runExpansionService (type: JavaExec) { classpath = sourceSets.main.runtimeClasspath args = [project.findProperty("constructionService.port") ?: "8097"] } - -compileJava { - outputs.upToDateWhen { false } -} \ No newline at end of file diff --git a/sdks/java/extensions/sql/expansion-service/build.gradle b/sdks/java/extensions/sql/expansion-service/build.gradle index b8d78e4e1bb9..b6963cf7547b 100644 --- a/sdks/java/extensions/sql/expansion-service/build.gradle +++ b/sdks/java/extensions/sql/expansion-service/build.gradle @@ -46,7 +46,3 @@ task runExpansionService (type: JavaExec) { classpath = sourceSets.main.runtimeClasspath args = [project.findProperty("constructionService.port") ?: "8097"] } - -shadowJar { - outputs.upToDateWhen { false } -} \ No newline at end of file diff --git a/sdks/java/io/expansion-service/build.gradle b/sdks/java/io/expansion-service/build.gradle index cc8eccf98997..8b817163ae39 100644 --- a/sdks/java/io/expansion-service/build.gradle +++ b/sdks/java/io/expansion-service/build.gradle @@ -35,7 +35,6 @@ configurations.runtimeClasspath { shadowJar { mergeServiceFiles() - outputs.upToDateWhen { false } } description = "Apache Beam :: SDKs :: Java :: IO :: Expansion Service" diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index c2cba936abce..ef2316f51f0e 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -117,7 +117,7 @@ def affinity(self, other: "Provider"): (e.g. to encourage fusion). """ # TODO(yaml): This is a very rough heuristic. Consider doing better. - # E.g. we could look at the the expected environments themselves. + # E.g. we could look at the expected environments themselves. # Possibly, we could provide multiple expansions and have the runner itself # choose the actual implementation based on fusion (and other) criteria. 
a = self.underlying_provider() From e52868c29f9d6cae3c91aedff9814f90de241b36 Mon Sep 17 00:00:00 2001 From: pablo rodriguez defino Date: Tue, 15 Oct 2024 12:28:22 -0700 Subject: [PATCH 32/82] Enable BigQuery CDC configuration for Python BigQuery sink (#32529) * include CDC configuration on the storage write transform provider * adding the primary key configuration for CDC and tests * fixing List.of references to use ImmutableList * fixing test, missing calling the cdc info row builder() method * fix test, add config validations * added the xlang params to storage write python wrapper * adding missing comma * shortening property name * changing xlang config property * set use cdc schema property as nullable, added safe retrieval method * fixes property name reference and argument type definition * python format fix * adding xlang IT with BQ * adding missing primary key column to test * python format fix * format xlang test * more format xlang test fixes * and more format xlang test fixes * adding missing import * missing self reference * enabled create if needed functionality for CDC python integration, implemented table constraint support on the bigquery fake dataset services * Update bigquery.py * triggering the xlang tests * fixing lint * addressing few comments * cdc info is added after row transformation now * remove not used param * removed typing information for callable * adding test for cdc using dicts as input and cdc write callable * simplifying the xlang configuration from python perspective, will add callable on a future PR * spotless apply * wrong property passed to xlang builder * missing self * fixing xlang it * fixes wrong property reference * change cdc xlang test to use beam.io.WriteToBigQuery * force another build * modifying comment to trigger build. * addressing PR comments, included new dicts based test for xlang python tests, included the CDC configurations into the existing RowDynamicDestinations object, improved error message for mutation information schema checks. 
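Example usage (a minimal sketch adapted from the integration test added in this change; the
table spec, field names, and values below are illustrative only):

    import apache_beam as beam

    # Each element wraps the payload in `record` and carries its CDC metadata
    # in `row_mutation_info`, as expected by the cross-language transform.
    rows = [
        beam.Row(
            row_mutation_info=beam.Row(
                mutation_type="UPSERT", change_sequence_number="AAA/1"),
            record=beam.Row(name="cdc_test", value=3)),
    ]

    with beam.Pipeline() as p:
        _ = (
            p
            | beam.Create(rows)
            | beam.io.WriteToBigQuery(
                table="project:dataset.cdc_table",  # illustrative table spec
                method=beam.io.WriteToBigQuery.Method.STORAGE_WRITE_API,
                use_at_least_once=True,  # CDC writes run in at-least-once mode
                use_cdc_writes=True,     # new flag introduced in this change
                primary_key=["name"]))   # required when CREATE_IF_NEEDED is used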
--- ..._PostCommit_Python_Xlang_Gcp_Dataflow.json | 3 +- ...am_PostCommit_Python_Xlang_Gcp_Direct.json | 2 +- .../beam/sdk/io/gcp/bigquery/BigQueryIO.java | 8 +- ...torageWriteApiSchemaTransformProvider.java | 122 +++++++++++++++++- .../sdk/io/gcp/testing/TableContainer.java | 32 ++++- ...geWriteApiSchemaTransformProviderTest.java | 115 ++++++++++++++++- .../io/external/xlang_bigqueryio_it_test.py | 122 ++++++++++++++++++ sdks/python/apache_beam/io/gcp/bigquery.py | 30 ++++- 8 files changed, 414 insertions(+), 20 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json index 6b3a9dc134ee..27c1f3ae26cd 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json @@ -1,4 +1,5 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json index e3d6056a5de9..b26833333238 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 2 } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index 88dfa2c26348..84bf90bd4121 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -2259,6 +2259,7 @@ public static Write applyRowMutations() { .withFormatFunction(RowMutation::getTableRow) .withRowMutationInformationFn(RowMutation::getMutationInformation); } + /** * A {@link PTransform} that writes a {@link PCollection} containing {@link GenericRecord * GenericRecords} to a BigQuery table. @@ -2367,8 +2368,10 @@ public enum Method { abstract WriteDisposition getWriteDisposition(); abstract Set getSchemaUpdateOptions(); + /** Table description. Default is empty. */ abstract @Nullable String getTableDescription(); + /** An option to indicate if table validation is desired. Default is true. 
*/ abstract boolean getValidate(); @@ -3455,7 +3458,10 @@ && getStorageApiTriggeringFrequency(bqOptions) != null) { LOG.error("The Storage API sink does not support the WRITE_TRUNCATE write disposition."); } if (getRowMutationInformationFn() != null) { - checkArgument(getMethod() == Method.STORAGE_API_AT_LEAST_ONCE); + checkArgument( + getMethod() == Method.STORAGE_API_AT_LEAST_ONCE, + "When using row updates on BigQuery, StorageWrite API should execute using" + + " \"at least once\" mode."); checkArgument( getCreateDisposition() == CreateDisposition.CREATE_NEVER || getPrimaryKey() != null, "If specifying CREATE_IF_NEEDED along with row updates, a primary key needs to be specified"); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java index 980d783ec43c..c1c06fc592f4 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java @@ -20,6 +20,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import com.google.api.services.bigquery.model.TableConstraints; import com.google.api.services.bigquery.model.TableSchema; import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; @@ -27,6 +28,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import javax.annotation.Nullable; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; @@ -37,6 +39,7 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryStorageApiInsertError; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations; +import org.apache.beam.sdk.io.gcp.bigquery.RowMutationInformation; import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; import org.apache.beam.sdk.io.gcp.bigquery.WriteResult; import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration; @@ -87,6 +90,14 @@ public class BigQueryStorageWriteApiSchemaTransformProvider private static final String FAILED_ROWS_WITH_ERRORS_TAG = "FailedRowsWithErrors"; // magic string that tells us to write to dynamic destinations protected static final String DYNAMIC_DESTINATIONS = "DYNAMIC_DESTINATIONS"; + protected static final String ROW_PROPERTY_MUTATION_INFO = "row_mutation_info"; + protected static final String ROW_PROPERTY_MUTATION_TYPE = "mutation_type"; + protected static final String ROW_PROPERTY_MUTATION_SQN = "change_sequence_number"; + protected static final Schema ROW_SCHEMA_MUTATION_INFO = + Schema.builder() + .addStringField("mutation_type") + .addStringField("change_sequence_number") + .build(); @Override protected SchemaTransform from( @@ -257,6 +268,20 @@ public static Builder builder() { @Nullable public abstract ErrorHandling getErrorHandling(); + @SchemaFieldDescription( + "This option enables the use of BigQuery 
CDC functionality. The expected PCollection" + + " should contain Beam Rows with a schema wrapping the record to be inserted and" + + " adding the CDC info similar to: {row_mutation_info: {mutation_type:\"...\", " + + "change_sequence_number:\"...\"}, record: {...}}") + @Nullable + public abstract Boolean getUseCdcWrites(); + + @SchemaFieldDescription( + "If CREATE_IF_NEEDED disposition is set, BigQuery table(s) will be created with this" + + " columns as primary key. Required when CDC writes are enabled with CREATE_IF_NEEDED.") + @Nullable + public abstract List getPrimaryKey(); + /** Builder for {@link BigQueryStorageWriteApiSchemaTransformConfiguration}. */ @AutoValue.Builder public abstract static class Builder { @@ -277,6 +302,10 @@ public abstract static class Builder { public abstract Builder setErrorHandling(ErrorHandling errorHandling); + public abstract Builder setUseCdcWrites(Boolean cdcWrites); + + public abstract Builder setPrimaryKey(List pkColumns); + /** Builds a {@link BigQueryStorageWriteApiSchemaTransformConfiguration} instance. */ public abstract BigQueryStorageWriteApiSchemaTransformProvider .BigQueryStorageWriteApiSchemaTransformConfiguration @@ -343,15 +372,27 @@ public void process(ProcessContext c) {} } private static class RowDynamicDestinations extends DynamicDestinations { - Schema schema; + final Schema schema; + final String fixedDestination; + final List primaryKey; RowDynamicDestinations(Schema schema) { this.schema = schema; + this.fixedDestination = null; + this.primaryKey = null; + } + + public RowDynamicDestinations( + Schema schema, String fixedDestination, List primaryKey) { + this.schema = schema; + this.fixedDestination = fixedDestination; + this.primaryKey = primaryKey; } @Override public String getDestination(ValueInSingleWindow element) { - return element.getValue().getString("destination"); + return Optional.ofNullable(fixedDestination) + .orElseGet(() -> element.getValue().getString("destination")); } @Override @@ -363,6 +404,17 @@ public TableDestination getTable(String destination) { public TableSchema getSchema(String destination) { return BigQueryUtils.toTableSchema(schema); } + + @Override + public TableConstraints getTableConstraints(String destination) { + return Optional.ofNullable(this.primaryKey) + .filter(pk -> !pk.isEmpty()) + .map( + pk -> + new TableConstraints() + .setPrimaryKey(new TableConstraints.PrimaryKey().setColumns(pk))) + .orElse(null); + } } @Override @@ -453,6 +505,13 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } } + void validateDynamicDestinationsExpectedSchema(Schema schema) { + checkArgument( + schema.getFieldNames().containsAll(Arrays.asList("destination", "record")), + "When writing to dynamic destinations, we expect Row Schema with a " + + "\"destination\" string field and a \"record\" Row field."); + } + BigQueryIO.Write createStorageWriteApiTransform(Schema schema) { Method writeMethod = configuration.getUseAtLeastOnceSemantics() != null @@ -466,11 +525,11 @@ BigQueryIO.Write createStorageWriteApiTransform(Schema schema) { .withFormatFunction(BigQueryUtils.toTableRow()) .withWriteDisposition(WriteDisposition.WRITE_APPEND); - if (configuration.getTable().equals(DYNAMIC_DESTINATIONS)) { - checkArgument( - schema.getFieldNames().equals(Arrays.asList("destination", "record")), - "When writing to dynamic destinations, we expect Row Schema with a " - + "\"destination\" string field and a \"record\" Row field."); + // in case CDC writes are configured we validate and include them in the 
configuration + if (Optional.ofNullable(configuration.getUseCdcWrites()).orElse(false)) { + write = validateAndIncludeCDCInformation(write, schema); + } else if (configuration.getTable().equals(DYNAMIC_DESTINATIONS)) { + validateDynamicDestinationsExpectedSchema(schema); write = write .to(new RowDynamicDestinations(schema.getField("record").getType().getRowSchema())) @@ -485,6 +544,7 @@ BigQueryIO.Write createStorageWriteApiTransform(Schema schema) { configuration.getCreateDisposition().toUpperCase()); write = write.withCreateDisposition(createDisposition); } + if (!Strings.isNullOrEmpty(configuration.getWriteDisposition())) { WriteDisposition writeDisposition = BigQueryStorageWriteApiSchemaTransformConfiguration.WRITE_DISPOSITIONS.get( @@ -498,5 +558,53 @@ BigQueryIO.Write createStorageWriteApiTransform(Schema schema) { return write; } + + BigQueryIO.Write validateAndIncludeCDCInformation( + BigQueryIO.Write write, Schema schema) { + checkArgument( + schema.getFieldNames().containsAll(Arrays.asList(ROW_PROPERTY_MUTATION_INFO, "record")), + "When writing using CDC functionality, we expect Row Schema with a " + + "\"" + + ROW_PROPERTY_MUTATION_INFO + + "\" Row field and a \"record\" Row field."); + + Schema rowSchema = schema.getField(ROW_PROPERTY_MUTATION_INFO).getType().getRowSchema(); + + checkArgument( + rowSchema.equals(ROW_SCHEMA_MUTATION_INFO), + "When writing using CDC functionality, we expect a \"" + + ROW_PROPERTY_MUTATION_INFO + + "\" field of Row type with schema:\n" + + ROW_SCHEMA_MUTATION_INFO.toString() + + "\n" + + "Received \"" + + ROW_PROPERTY_MUTATION_INFO + + "\" field with schema:\n" + + rowSchema.toString()); + + String tableDestination = null; + + if (configuration.getTable().equals(DYNAMIC_DESTINATIONS)) { + validateDynamicDestinationsExpectedSchema(schema); + } else { + tableDestination = configuration.getTable(); + } + + return write + .to( + new RowDynamicDestinations( + schema.getField("record").getType().getRowSchema(), + tableDestination, + configuration.getPrimaryKey())) + .withFormatFunction(row -> BigQueryUtils.toTableRow(row.getRow("record"))) + .withPrimaryKey(configuration.getPrimaryKey()) + .withRowMutationInformationFn( + row -> + RowMutationInformation.of( + RowMutationInformation.MutationType.valueOf( + row.getRow(ROW_PROPERTY_MUTATION_INFO) + .getString(ROW_PROPERTY_MUTATION_TYPE)), + row.getRow(ROW_PROPERTY_MUTATION_INFO).getString(ROW_PROPERTY_MUTATION_SQN))); + } } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/TableContainer.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/TableContainer.java index b50aa4d32d76..b44b9596cc12 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/TableContainer.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/TableContainer.java @@ -18,11 +18,13 @@ package org.apache.beam.sdk.io.gcp.testing; import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableConstraints; import com.google.api.services.bigquery.model.TableRow; import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.beam.sdk.io.gcp.bigquery.TableRowJsonCoder; @@ -51,12 +53,24 @@ class TableContainer { this.keyedRows = Maps.newHashMap(); this.ids = new ArrayList<>(); 
this.sizeBytes = 0L; + // extract primary key information from Table if present + List pkColumns = primaryKeyColumns(table); + this.primaryKeyColumns = pkColumns; + this.primaryKeyColumnIndices = primaryColumnFieldIndices(pkColumns, table); } - // Only top-level columns supported. - void setPrimaryKeyColumns(List primaryKeyColumns) { - this.primaryKeyColumns = primaryKeyColumns; + static @Nullable List primaryKeyColumns(Table table) { + return Optional.ofNullable(table.getTableConstraints()) + .flatMap(constraints -> Optional.ofNullable(constraints.getPrimaryKey())) + .map(TableConstraints.PrimaryKey::getColumns) + .orElse(null); + } + static @Nullable List primaryColumnFieldIndices( + @Nullable List primaryKeyColumns, Table table) { + if (primaryKeyColumns == null) { + return null; + } Map indices = IntStream.range(0, table.getSchema().getFields().size()) .boxed() @@ -65,7 +79,13 @@ void setPrimaryKeyColumns(List primaryKeyColumns) { for (String columnName : primaryKeyColumns) { primaryKeyColumnIndices.add(Preconditions.checkStateNotNull(indices.get(columnName))); } - this.primaryKeyColumnIndices = primaryKeyColumnIndices; + return primaryKeyColumnIndices; + } + + // Only top-level columns supported. + void setPrimaryKeyColumns(List primaryKeyColumns) { + this.primaryKeyColumns = primaryKeyColumns; + this.primaryKeyColumnIndices = primaryColumnFieldIndices(primaryKeyColumns, table); } @Nullable @@ -80,7 +100,7 @@ List getPrimaryKey(TableRow tableRow) { .stream() .map(cell -> Preconditions.checkStateNotNull(cell.get("v"))) .collect(Collectors.toList()); - ; + return Preconditions.checkStateNotNull(primaryKeyColumnIndices).stream() .map(cellValues::get) .collect(Collectors.toList()); @@ -91,7 +111,7 @@ List getPrimaryKey(TableRow tableRow) { long addRow(TableRow row, String id) { List primaryKey = getPrimaryKey(row); - if (primaryKey != null) { + if (primaryKey != null && !primaryKey.isEmpty()) { if (keyedRows.putIfAbsent(primaryKey, row) != null) { throw new RuntimeException( "Primary key validation error! 
Multiple inserts with the same primary key."); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java index 64ea0b11d1b9..87ba2961461a 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java @@ -30,6 +30,8 @@ import java.util.List; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransform; @@ -54,6 +56,7 @@ import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -221,6 +224,117 @@ public void testWriteToDynamicDestinations() throws Exception { fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_3").get(0))); } + List createCDCUpsertRows(List rows, boolean dynamicDestination, String tablePrefix) { + + Schema.Builder schemaBuilder = + Schema.builder() + .addRowField("record", SCHEMA) + .addRowField( + BigQueryStorageWriteApiSchemaTransformProvider.ROW_PROPERTY_MUTATION_INFO, + BigQueryStorageWriteApiSchemaTransformProvider.ROW_SCHEMA_MUTATION_INFO); + + if (dynamicDestination) { + schemaBuilder = schemaBuilder.addStringField("destination"); + } + + Schema schemaWithCDC = schemaBuilder.build(); + return IntStream.range(0, rows.size()) + .mapToObj( + idx -> { + Row row = rows.get(idx); + Row.FieldValueBuilder rowBuilder = + Row.withSchema(schemaWithCDC) + .withFieldValue( + BigQueryStorageWriteApiSchemaTransformProvider.ROW_PROPERTY_MUTATION_INFO, + Row.withSchema( + BigQueryStorageWriteApiSchemaTransformProvider + .ROW_SCHEMA_MUTATION_INFO) + .withFieldValue( + BigQueryStorageWriteApiSchemaTransformProvider + .ROW_PROPERTY_MUTATION_TYPE, + "UPSERT") + .withFieldValue( + BigQueryStorageWriteApiSchemaTransformProvider + .ROW_PROPERTY_MUTATION_SQN, + "AAA" + idx) + .build()) + .withFieldValue("record", row); + if (dynamicDestination) { + rowBuilder = + rowBuilder.withFieldValue("destination", tablePrefix + row.getInt64("number")); + } + return rowBuilder.build(); + }) + .collect(Collectors.toList()); + } + + @Test + public void testCDCWrites() throws Exception { + String tableSpec = "project:dataset.cdc_write"; + List primaryKeyColumns = ImmutableList.of("name"); + + BigQueryStorageWriteApiSchemaTransformConfiguration config = + BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + .setUseAtLeastOnceSemantics(true) + .setTable(tableSpec) + .setUseCdcWrites(true) + .setPrimaryKey(primaryKeyColumns) + .build(); + + List rowsDuplicated = + Stream.concat(ROWS.stream(), ROWS.stream()).collect(Collectors.toList()); + + runWithConfig(config, createCDCUpsertRows(rowsDuplicated, false, "")); + 
p.run().waitUntilFinish(); + + assertTrue( + rowEquals( + rowsDuplicated.get(3), + fakeDatasetService.getAllRows("project", "dataset", "cdc_write").get(0))); + assertTrue( + rowEquals( + rowsDuplicated.get(4), + fakeDatasetService.getAllRows("project", "dataset", "cdc_write").get(1))); + assertTrue( + rowEquals( + rowsDuplicated.get(5), + fakeDatasetService.getAllRows("project", "dataset", "cdc_write").get(2))); + } + + @Test + public void testCDCWriteToDynamicDestinations() throws Exception { + List primaryKeyColumns = ImmutableList.of("name"); + String dynamic = BigQueryStorageWriteApiSchemaTransformProvider.DYNAMIC_DESTINATIONS; + BigQueryStorageWriteApiSchemaTransformConfiguration config = + BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + .setUseAtLeastOnceSemantics(true) + .setTable(dynamic) + .setUseCdcWrites(true) + .setPrimaryKey(primaryKeyColumns) + .build(); + + String baseTableSpec = "project:dataset.cdc_dynamic_write_"; + + List rowsDuplicated = + Stream.concat(ROWS.stream(), ROWS.stream()).collect(Collectors.toList()); + + runWithConfig(config, createCDCUpsertRows(rowsDuplicated, true, baseTableSpec)); + p.run().waitUntilFinish(); + + assertTrue( + rowEquals( + rowsDuplicated.get(3), + fakeDatasetService.getAllRows("project", "dataset", "cdc_dynamic_write_1").get(0))); + assertTrue( + rowEquals( + rowsDuplicated.get(4), + fakeDatasetService.getAllRows("project", "dataset", "cdc_dynamic_write_2").get(0))); + assertTrue( + rowEquals( + rowsDuplicated.get(5), + fakeDatasetService.getAllRows("project", "dataset", "cdc_dynamic_write_3").get(0))); + } + @Test public void testInputElementCount() throws Exception { String tableSpec = "project:dataset.input_count"; @@ -292,7 +406,6 @@ public void testFailedRows() throws Exception { MapElements.into(TypeDescriptors.rows()) .via((rowAndError) -> rowAndError.getValue("failed_row"))) .setRowSchema(SCHEMA); - ; PAssert.that(failedRows).containsInAnyOrder(expectedFailedRows); p.run().waitUntilFinish(); diff --git a/sdks/python/apache_beam/io/external/xlang_bigqueryio_it_test.py b/sdks/python/apache_beam/io/external/xlang_bigqueryio_it_test.py index cfbb411b4e5f..7f3a16e02aa3 100644 --- a/sdks/python/apache_beam/io/external/xlang_bigqueryio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_bigqueryio_it_test.py @@ -245,6 +245,128 @@ def test_write_with_beam_rows(self): | StorageWriteToBigQuery(table=table_id)) hamcrest_assert(p, bq_matcher) + def test_write_with_beam_rows_cdc(self): + table = 'write_with_beam_rows_cdc' + table_id = '{}:{}.{}'.format(self.project, self.dataset_id, table) + + expected_data_on_bq = [ + # (name, value) + { + "name": "cdc_test", + "value": 5, + } + ] + + rows_with_cdc = [ + beam.Row( + row_mutation_info=beam.Row( + mutation_type="UPSERT", change_sequence_number="AAA/2"), + record=beam.Row(name="cdc_test", value=5)), + beam.Row( + row_mutation_info=beam.Row( + mutation_type="UPSERT", change_sequence_number="AAA/1"), + record=beam.Row(name="cdc_test", value=3)) + ] + + bq_matcher = BigqueryFullResultMatcher( + project=self.project, + query="SELECT * FROM {}.{}".format(self.dataset_id, table), + data=self.parse_expected_data(expected_data_on_bq)) + + with beam.Pipeline(argv=self.args) as p: + _ = ( + p + | beam.Create(rows_with_cdc) + | beam.io.WriteToBigQuery( + table=table_id, + method=beam.io.WriteToBigQuery.Method.STORAGE_WRITE_API, + use_at_least_once=True, + use_cdc_writes=True, + primary_key=["name"])) + hamcrest_assert(p, bq_matcher) + + def test_write_with_dicts_cdc(self): + table = 
'write_with_dicts_cdc' + table_id = '{}:{}.{}'.format(self.project, self.dataset_id, table) + + expected_data_on_bq = [ + # (name, value) + { + "name": "cdc_test", + "value": 5, + } + ] + + data_with_cdc = [ + # record: (name, value) + { + 'row_mutation_info': { + 'mutation_type': 'UPSERT', 'change_sequence_number': 'AAA/2' + }, + 'record': { + 'name': 'cdc_test', 'value': 5 + } + }, + { + 'row_mutation_info': { + 'mutation_type': 'UPSERT', 'change_sequence_number': 'AAA/1' + }, + 'record': { + 'name': 'cdc_test', 'value': 3 + } + } + ] + + schema = { + "fields": [ + # include both record and mutation info fields as part of the schema + { + "name": "row_mutation_info", + "type": "STRUCT", + "fields": [ + # setting both fields are required + { + "name": "mutation_type", + "type": "STRING", + "mode": "REQUIRED" + }, + { + "name": "change_sequence_number", + "type": "STRING", + "mode": "REQUIRED" + } + ] + }, + { + "name": "record", + "type": "STRUCT", + "fields": [{ + "name": "name", "type": "STRING" + }, { + "name": "value", "type": "INTEGER" + }] + } + ] + } + + bq_matcher = BigqueryFullResultMatcher( + project=self.project, + query="SELECT * FROM {}.{}".format(self.dataset_id, table), + data=self.parse_expected_data(expected_data_on_bq)) + + with beam.Pipeline(argv=self.args) as p: + _ = ( + p + | beam.Create(data_with_cdc) + | beam.io.WriteToBigQuery( + table=table_id, + method=beam.io.WriteToBigQuery.Method.STORAGE_WRITE_API, + use_at_least_once=True, + use_cdc_writes=True, + schema=schema, + primary_key=["name"])) + hamcrest_assert(p, bq_matcher) + def test_write_to_dynamic_destinations(self): base_table_spec = '{}.dynamic_dest_'.format(self.dataset_id) spec_with_project = '{}:{}'.format(self.project, base_table_spec) diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index 2cb64742f26c..11e0d098b2f3 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -1930,6 +1930,8 @@ def __init__( load_job_project_id=None, max_insert_payload_size=MAX_INSERT_PAYLOAD_SIZE, num_streaming_keys=DEFAULT_SHARDS_PER_DESTINATION, + use_cdc_writes: bool = False, + primary_key: List[str] = None, expansion_service=None): """Initialize a WriteToBigQuery transform. @@ -2095,6 +2097,15 @@ def __init__( GCP expansion service. Used for STORAGE_WRITE_API method. max_insert_payload_size: The maximum byte size for a BigQuery legacy streaming insert payload. + use_cdc_writes: Configure the usage of CDC writes on BigQuery. + The argument can be used by passing True and the Beam Rows will be + sent as they are to the BigQuery sink which expects a 'record' + and 'row_mutation_info' properties. + Used for STORAGE_WRITE_API, working on 'at least once' mode. + primary_key: When using CDC write on BigQuery and + CREATE_IF_NEEDED mode for the underlying tables a list of column names + is required to be configured as the primary key. Used for + STORAGE_WRITE_API, working on 'at least once' mode. """ self._table = table self._dataset = dataset @@ -2136,6 +2147,8 @@ def __init__( self.load_job_project_id = load_job_project_id self._max_insert_payload_size = max_insert_payload_size self._num_streaming_keys = num_streaming_keys + self._use_cdc_writes = use_cdc_writes + self._primary_key = primary_key # Dict/schema methods were moved to bigquery_tools, but keep references # here for backward compatibility. 
@@ -2289,8 +2302,9 @@ def find_in_nested_dict(schema): use_at_least_once=self.use_at_least_once, with_auto_sharding=self.with_auto_sharding, num_storage_api_streams=self._num_storage_api_streams, + use_cdc_writes=self._use_cdc_writes, + primary_key=self._primary_key, expansion_service=self.expansion_service) - else: raise ValueError(f"Unsupported method {method_to_use}") @@ -2518,6 +2532,10 @@ class StorageWriteToBigQuery(PTransform): # fields for rows sent to Storage API with dynamic destinations DESTINATION = "destination" RECORD = "record" + # field names for rows sent to Storage API for CDC functionality + CDC_INFO = "row_mutation_info" + CDC_MUTATION_TYPE = "mutation_type" + CDC_SQN = "change_sequence_number" # magic string to tell Java that these rows are going to dynamic destinations DYNAMIC_DESTINATIONS = "DYNAMIC_DESTINATIONS" @@ -2532,6 +2550,8 @@ def __init__( use_at_least_once=False, with_auto_sharding=False, num_storage_api_streams=0, + use_cdc_writes: bool = False, + primary_key: List[str] = None, expansion_service=None): self._table = table self._table_side_inputs = table_side_inputs @@ -2542,6 +2562,8 @@ def __init__( self._use_at_least_once = use_at_least_once self._with_auto_sharding = with_auto_sharding self._num_storage_api_streams = num_storage_api_streams + self._use_cdc_writes = use_cdc_writes + self._primary_key = primary_key self._expansion_service = expansion_service or BeamJarExpansionService( 'sdks:java:io:google-cloud-platform:expansion-service:build') @@ -2552,11 +2574,11 @@ def expand(self, input): is_rows = True except TypeError as exn: raise ValueError( - "A schema is required in order to prepare rows" + "A schema is required in order to prepare rows " "for writing with STORAGE_WRITE_API.") from exn elif callable(self._schema): raise NotImplementedError( - "Writing with dynamic schemas is not" + "Writing with dynamic schemas is not " "supported for this write method.") elif isinstance(self._schema, vp.ValueProvider): schema = self._schema.get() @@ -2624,6 +2646,8 @@ def expand(self, input): auto_sharding=self._with_auto_sharding, num_streams=self._num_storage_api_streams, use_at_least_once_semantics=self._use_at_least_once, + use_cdc_writes=self._use_cdc_writes, + primary_key=self._primary_key, error_handling={ 'output': StorageWriteToBigQuery.FAILED_ROWS_WITH_ERRORS })) From e39e5d724c946901ba7065442781bd3457b4c4de Mon Sep 17 00:00:00 2001 From: pablo rodriguez defino Date: Tue, 15 Oct 2024 12:29:07 -0700 Subject: [PATCH 33/82] Update CHANGES.md (#32788) --- CHANGES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGES.md b/CHANGES.md index 4e21e400e60d..f2b865cec236 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -67,6 +67,7 @@ * [Managed Iceberg] Support creating tables if needed ([#32686](https://github.com/apache/beam/pull/32686)) * [Managed Iceberg] Now available in Python SDK ([#31495](https://github.com/apache/beam/pull/31495)) * [Managed Iceberg] Add support for TIMESTAMP, TIME, and DATE types ([#32688](https://github.com/apache/beam/pull/32688)) +* BigQuery CDC writes are now available in Python SDK, only supported when using StorageWrite API at least once mode ([#32527](https://github.com/apache/beam/issues/32527)) ## New Features / Improvements From 0feaaa6c96b223fab520585786600104ab362a62 Mon Sep 17 00:00:00 2001 From: liferoad Date: Tue, 15 Oct 2024 16:43:09 -0400 Subject: [PATCH 34/82] Add the beam summit 2024 overview blog --- .../en/blog/beam-summit-2024-overview.md | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 
100644 website/www/site/content/en/blog/beam-summit-2024-overview.md
diff --git a/website/www/site/content/en/blog/beam-summit-2024-overview.md b/website/www/site/content/en/blog/beam-summit-2024-overview.md
new file mode 100644
index 000000000000..8d5eb209bea2
--- /dev/null
+++ b/website/www/site/content/en/blog/beam-summit-2024-overview.md
@@ -0,0 +1,59 @@
+---
+title: "Apache Beam Summit 2024: Unlocking the power of ML for data processing"
+date: 2024-10-16 00:00:01 -0800
+categories:
+ - blog
+aliases:
+ - /blog/2024/10/16/beam-summit-2024-overview.html
+authors:
+ - liferoad
+ - damccorm
+ - rezarokni
+---
+
+
+At the recently concluded [Beam Summit 2024](https://beamsummit.org/), a two-day event held from September 4 to 5, numerous captivating presentations showcased the potential of Beam to address a wide range of challenges, with an emphasis on machine learning (ML). These challenges included feature engineering, data enrichment, and model inference for large-scale distributed data. In all, the summit included [47 talks](https://beamsummit.org/sessions/2024/), with 16 focused specifically on ML use cases or features and many more touching on these topics.
+
+The talks displayed the breadth and diversity of the Beam community. Among the speakers and attendees, [23 countries](https://docs.google.com/presentation/d/1IJ1sExHzrzIFF5QXKWlcAuPdp7lKOepRQKl9BnfHxJw/edit#slide=id.g3058d3e2f5f_0_10) were represented. Attendees included Beam users, committers in the Beam project, Beam Google Summer of Code contributors, and data processing/machine learning experts.
+
+## User-friendly turnkey transforms for ML
+
+With the features recently added to Beam, Beam now offers a rich set of turnkey transforms for ML users that handle a wide range of ML-Ops tasks. These transforms include:
+
+* [RunInference](https://beam.apache.org/documentation/ml/overview/#prediction-and-inference): deploy ML models on CPUs and GPUs
+* [Enrichment](https://beam.apache.org/documentation/ml/overview/#data-processing): enrich data for ML feature enhancements
+* [MLTransform](https://beam.apache.org/documentation/ml/overview/#data-processing): transform data into ML features
+
+The Summit talks covered both how to use these features and how people are already using them. Highlights included:
+
+* A talk about [scaling autonomous driving at Cruise](https://beamsummit.org/slides/2024/ScalingAutonomousDrivingwithApacheBeam.pdf)
+* Multiple talks about deploying LLMs for batch and streaming inference
+* Three different talks about streaming processing for [RAG](https://cloud.google.com/use-cases/retrieval-augmented-generation) (including [a talk](https://www.youtube.com/watch?v=X_VzKQOcpC4) from one of Beam's Google Summer of Code contributors\!)
+
+## Beam YAML: Simplifying ML data processing
+
+Beam pipeline creation can be challenging and often requires learning concepts, managing dependencies, debugging, and maintaining code for ML tasks. To simplify the entry point, [Beam YAML](https://beam.apache.org/blog/beam-yaml-release/) introduces a declarative approach that uses YAML configuration files to create data processing pipelines. No coding is required.
+
+Beam Summit was the first opportunity that the Beam community had to show off some of the use cases of Beam YAML. It featured several talks about how Beam YAML is already a core part of many users' workflows at companies like [MavenCode](https://beamsummit.org/slides/2024/ALowCodeStructuredApproachtoDeployingApacheBeamMLWorkloadsonKubernetesusingBeamStack.pdf) and [ChartBoost](https://youtu.be/avSXvbScbW0). With Beam YAML, these companies are able to build configuration-based data processing systems, significantly lowering the barrier to entry.
+
+## Prism: Providing a unified ML pipeline development framework for local and remote runner environments
+
+Beam provides a variety of support for portable runners, but developing a local pipeline has traditionally been painful. Local runners are often incomplete and incompatible with remote runners, such as DataflowRunner and FlinkRunner.
+
+At Beam Summit, Beam contributors introduced [the Prism local runner](https://youtu.be/R4iNwLBa3VQ) to the community. Prism greatly improves the local developer experience and reduces the gap between local and remote execution. In particular, when handling complicated ML tasks, Prism guarantees consistent behavior across these runners, something that previously lacked reliable support.
+
+## Summary
+
+Beam Summit 2024 showcased the tremendous potential of Apache Beam for addressing a wide range of data processing and machine learning challenges. We look forward to seeing even more innovative use cases and contributions in the future.
+
+To stay updated on the latest Beam developments and events, visit [the Apache Beam website](https://beam.apache.org/get-started/) and follow us on [social media](https://www.linkedin.com/company/apache-beam/). We encourage you to join [the Beam community](https://beam.apache.org/community/contact-us/) and [contribute to the project](https://beam.apache.org/contribute/). Together, let's unlock the full potential of Beam and shape the future of data processing and machine learning.
\ No newline at end of file From a50f91c386c00940b08ef8a5e4d0817422ea230f Mon Sep 17 00:00:00 2001 From: reuvenlax Date: Tue, 15 Oct 2024 14:04:30 -0700 Subject: [PATCH 35/82] Merge pull request #32757: Schema inference parameterized types --- .../beam/sdk/schemas/AutoValueSchema.java | 8 +- .../schemas/FieldValueTypeInformation.java | 89 ++++++----- .../beam/sdk/schemas/JavaBeanSchema.java | 12 +- .../beam/sdk/schemas/JavaFieldSchema.java | 10 +- .../beam/sdk/schemas/SchemaProvider.java | 3 +- .../beam/sdk/schemas/SchemaRegistry.java | 39 ++--- .../transforms/providers/JavaRowUdf.java | 3 +- .../sdk/schemas/utils/AutoValueUtils.java | 20 ++- .../sdk/schemas/utils/ByteBuddyUtils.java | 53 ++++--- .../sdk/schemas/utils/ConvertHelpers.java | 6 +- .../beam/sdk/schemas/utils/JavaBeanUtils.java | 10 +- .../beam/sdk/schemas/utils/POJOUtils.java | 20 ++- .../beam/sdk/schemas/utils/ReflectUtils.java | 83 ++++++++-- .../schemas/utils/StaticSchemaInference.java | 91 +++++------ .../beam/sdk/schemas/AutoValueSchemaTest.java | 149 ++++++++++++++++++ .../beam/sdk/schemas/JavaBeanSchemaTest.java | 124 +++++++++++++++ .../beam/sdk/schemas/JavaFieldSchemaTest.java | 120 ++++++++++++++ .../sdk/schemas/utils/JavaBeanUtilsTest.java | 33 +++- .../beam/sdk/schemas/utils/POJOUtilsTest.java | 36 +++-- .../beam/sdk/schemas/utils/TestJavaBeans.java | 91 +++++++++++ .../beam/sdk/schemas/utils/TestPOJOs.java | 121 +++++++++++++- .../schemas/utils/AvroByteBuddyUtils.java | 6 +- .../avro/schemas/utils/AvroUtils.java | 10 +- .../protobuf/ProtoByteBuddyUtils.java | 4 +- .../protobuf/ProtoMessageSchema.java | 8 +- .../python/PythonExternalTransform.java | 4 +- .../beam/sdk/io/thrift/ThriftSchema.java | 5 +- 27 files changed, 961 insertions(+), 197 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java index 5ccfe39b92af..c369eefeb65c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java @@ -19,8 +19,10 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.lang.reflect.Type; import java.util.Comparator; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; import org.apache.beam.sdk.schemas.utils.AutoValueUtils; @@ -61,8 +63,9 @@ public List get(TypeDescriptor typeDescriptor) { .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes)); } types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); @@ -143,7 +146,8 @@ public SchemaUserTypeCreator schemaTypeCreator( @Override public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); return JavaBeanUtils.schemaFromJavaBeanClass( - typeDescriptor, AbstractGetterTypeSupplier.INSTANCE); + typeDescriptor, AbstractGetterTypeSupplier.INSTANCE, boundTypes); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java index 750709192c08..64687e6d3381 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java @@ -24,10 +24,12 @@ import java.lang.reflect.Field; import java.lang.reflect.Member; import java.lang.reflect.Method; +import java.lang.reflect.Type; import java.util.Arrays; import java.util.Collections; import java.util.Map; import java.util.stream.Stream; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; @@ -44,6 +46,7 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) +@Internal public abstract class FieldValueTypeInformation implements Serializable { /** Optionally returns the field index. */ public abstract @Nullable Integer getNumber(); @@ -125,8 +128,10 @@ public static FieldValueTypeInformation forOneOf( .build(); } - public static FieldValueTypeInformation forField(Field field, int index) { - TypeDescriptor type = TypeDescriptor.of(field.getGenericType()); + public static FieldValueTypeInformation forField( + Field field, int index, Map boundTypes) { + TypeDescriptor type = + TypeDescriptor.of(ReflectUtils.resolveType(field.getGenericType(), boundTypes)); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(field.getName(), field)) .setNumber(getNumberOverride(index, field)) @@ -134,9 +139,9 @@ public static FieldValueTypeInformation forField(Field field, int index) { .setType(type) .setRawType(type.getRawType()) .setField(field) - .setElementType(getIterableComponentType(field)) - .setMapKeyType(getMapKeyType(field)) - .setMapValueType(getMapValueType(field)) + .setElementType(getIterableComponentType(field, boundTypes)) + .setMapKeyType(getMapKeyType(field, boundTypes)) + .setMapValueType(getMapValueType(field, boundTypes)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(field)) .build(); @@ -184,7 +189,8 @@ public static String getNameOverride( return fieldDescription.value(); } - public static FieldValueTypeInformation forGetter(Method method, int index) { + public static FieldValueTypeInformation forGetter( + Method method, int index, Map boundTypes) { String name; if (method.getName().startsWith("get")) { name = ReflectUtils.stripPrefix(method.getName(), "get"); @@ -194,7 +200,8 @@ public static FieldValueTypeInformation forGetter(Method method, int index) { throw new RuntimeException("Getter has wrong prefix " + method.getName()); } - TypeDescriptor type = TypeDescriptor.of(method.getGenericReturnType()); + TypeDescriptor type = + TypeDescriptor.of(ReflectUtils.resolveType(method.getGenericReturnType(), boundTypes)); boolean nullable = hasNullableReturnType(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(name, method)) @@ -203,9 +210,9 @@ public static FieldValueTypeInformation forGetter(Method method, int index) { .setType(type) .setRawType(type.getRawType()) .setMethod(method) - .setElementType(getIterableComponentType(type)) - .setMapKeyType(getMapKeyType(type)) - .setMapValueType(getMapValueType(type)) + .setElementType(getIterableComponentType(type, boundTypes)) + .setMapKeyType(getMapKeyType(type, boundTypes)) + 
.setMapValueType(getMapValueType(type, boundTypes)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(method)) .build(); @@ -252,11 +259,13 @@ private static boolean isNullableAnnotation(Annotation annotation) { return annotation.annotationType().getSimpleName().equals("Nullable"); } - public static FieldValueTypeInformation forSetter(Method method) { - return forSetter(method, "set"); + public static FieldValueTypeInformation forSetter( + Method method, Map boundParameters) { + return forSetter(method, "set", boundParameters); } - public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) { + public static FieldValueTypeInformation forSetter( + Method method, String setterPrefix, Map boundTypes) { String name; if (method.getName().startsWith(setterPrefix)) { name = ReflectUtils.stripPrefix(method.getName(), setterPrefix); @@ -264,7 +273,9 @@ public static FieldValueTypeInformation forSetter(Method method, String setterPr throw new RuntimeException("Setter has wrong prefix " + method.getName()); } - TypeDescriptor type = TypeDescriptor.of(method.getGenericParameterTypes()[0]); + TypeDescriptor type = + TypeDescriptor.of( + ReflectUtils.resolveType(method.getGenericParameterTypes()[0], boundTypes)); boolean nullable = hasSingleNullableParameter(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(name) @@ -272,9 +283,9 @@ public static FieldValueTypeInformation forSetter(Method method, String setterPr .setType(type) .setRawType(type.getRawType()) .setMethod(method) - .setElementType(getIterableComponentType(type)) - .setMapKeyType(getMapKeyType(type)) - .setMapValueType(getMapValueType(type)) + .setElementType(getIterableComponentType(type, boundTypes)) + .setMapKeyType(getMapKeyType(type, boundTypes)) + .setMapValueType(getMapValueType(type, boundTypes)) .setOneOfTypes(Collections.emptyMap()) .build(); } @@ -283,13 +294,15 @@ public FieldValueTypeInformation withName(String name) { return toBuilder().setName(name).build(); } - private static FieldValueTypeInformation getIterableComponentType(Field field) { - return getIterableComponentType(TypeDescriptor.of(field.getGenericType())); + private static FieldValueTypeInformation getIterableComponentType( + Field field, Map boundTypes) { + return getIterableComponentType(TypeDescriptor.of(field.getGenericType()), boundTypes); } - static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) { + static @Nullable FieldValueTypeInformation getIterableComponentType( + TypeDescriptor valueType, Map boundTypes) { // TODO: Figure out nullable elements. - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType); + TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType, boundTypes); if (componentType == null) { return null; } @@ -299,41 +312,43 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { .setNullable(false) .setType(componentType) .setRawType(componentType.getRawType()) - .setElementType(getIterableComponentType(componentType)) - .setMapKeyType(getMapKeyType(componentType)) - .setMapValueType(getMapValueType(componentType)) + .setElementType(getIterableComponentType(componentType, boundTypes)) + .setMapKeyType(getMapKeyType(componentType, boundTypes)) + .setMapValueType(getMapValueType(componentType, boundTypes)) .setOneOfTypes(Collections.emptyMap()) .build(); } // If the Field is a map type, returns the key type, otherwise returns a null reference. 
- private static @Nullable FieldValueTypeInformation getMapKeyType(Field field) { - return getMapKeyType(TypeDescriptor.of(field.getGenericType())); + private static @Nullable FieldValueTypeInformation getMapKeyType( + Field field, Map boundTypes) { + return getMapKeyType(TypeDescriptor.of(field.getGenericType()), boundTypes); } private static @Nullable FieldValueTypeInformation getMapKeyType( - TypeDescriptor typeDescriptor) { - return getMapType(typeDescriptor, 0); + TypeDescriptor typeDescriptor, Map boundTypes) { + return getMapType(typeDescriptor, 0, boundTypes); } // If the Field is a map type, returns the value type, otherwise returns a null reference. - private static @Nullable FieldValueTypeInformation getMapValueType(Field field) { - return getMapType(TypeDescriptor.of(field.getGenericType()), 1); + private static @Nullable FieldValueTypeInformation getMapValueType( + Field field, Map boundTypes) { + return getMapType(TypeDescriptor.of(field.getGenericType()), 1, boundTypes); } private static @Nullable FieldValueTypeInformation getMapValueType( - TypeDescriptor typeDescriptor) { - return getMapType(typeDescriptor, 1); + TypeDescriptor typeDescriptor, Map boundTypes) { + return getMapType(typeDescriptor, 1, boundTypes); } // If the Field is a map type, returns the key or value type (0 is key type, 1 is value). // Otherwise returns a null reference. @SuppressWarnings("unchecked") private static @Nullable FieldValueTypeInformation getMapType( - TypeDescriptor valueType, int index) { - TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index); + TypeDescriptor valueType, int index, Map boundTypes) { + TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index, boundTypes); if (mapType == null) { return null; } @@ -342,9 +357,9 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { .setNullable(false) .setType(mapType) .setRawType(mapType.getRawType()) - .setElementType(getIterableComponentType(mapType)) - .setMapKeyType(getMapKeyType(mapType)) - .setMapValueType(getMapValueType(mapType)) + .setElementType(getIterableComponentType(mapType, boundTypes)) + .setMapKeyType(getMapKeyType(mapType, boundTypes)) + .setMapValueType(getMapValueType(mapType, boundTypes)) .setOneOfTypes(Collections.emptyMap()) .build(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java index a9cf01c52057..ad71576670bf 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java @@ -19,8 +19,10 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.lang.reflect.Type; import java.util.Comparator; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; @@ -67,8 +69,9 @@ public List get(TypeDescriptor typeDescriptor) { .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes)); } 
types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); @@ -111,10 +114,11 @@ public static class SetterTypeSupplier implements FieldValueTypeSupplier { @Override public List get(TypeDescriptor typeDescriptor) { + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); return ReflectUtils.getMethods(typeDescriptor.getRawType()).stream() .filter(ReflectUtils::isSetter) .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) - .map(FieldValueTypeInformation::forSetter) + .map(m -> FieldValueTypeInformation.forSetter(m, boundTypes)) .map( t -> { if (t.getMethod().getAnnotation(SchemaFieldNumber.class) != null) { @@ -156,8 +160,10 @@ public boolean equals(@Nullable Object obj) { @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); Schema schema = - JavaBeanUtils.schemaFromJavaBeanClass(typeDescriptor, GetterTypeSupplier.INSTANCE); + JavaBeanUtils.schemaFromJavaBeanClass( + typeDescriptor, GetterTypeSupplier.INSTANCE, boundTypes); // If there are no creator methods, then validate that we have setters for every field. // Otherwise, we will have no way of creating instances of the class. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java index 21f07c47b47f..da0f59c8ee96 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java @@ -21,8 +21,10 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.lang.reflect.Type; import java.util.Comparator; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -62,9 +64,11 @@ public List get(TypeDescriptor typeDescriptor) { ReflectUtils.getFields(typeDescriptor.getRawType()).stream() .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); + List types = Lists.newArrayListWithCapacity(fields.size()); + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < fields.size(); ++i) { - types.add(FieldValueTypeInformation.forField(fields.get(i), i)); + types.add(FieldValueTypeInformation.forField(fields.get(i), i, boundTypes)); } types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); @@ -111,7 +115,9 @@ private static void validateFieldNumbers(List types) @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { - return POJOUtils.schemaFromPojoClass(typeDescriptor, JavaFieldTypeSupplier.INSTANCE); + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); + return POJOUtils.schemaFromPojoClass( + typeDescriptor, JavaFieldTypeSupplier.INSTANCE, boundTypes); } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java index 37b4952e529c..b7e3cdf60c18 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java @@ -38,8 +38,7 @@ public interface SchemaProvider extends Serializable { * Given a type, return a function that converts that type to a {@link Row} object If no schema * exists, returns null. 
*/ - @Nullable - SerializableFunction toRowFunction(TypeDescriptor typeDescriptor); + @Nullable SerializableFunction toRowFunction(TypeDescriptor typeDescriptor); /** * Given a type, returns a function that converts from a {@link Row} object to that type. If no diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java index 679a1fcf54fc..5d8b7aab6193 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java @@ -76,13 +76,12 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid providers.put(typeDescriptor, schemaProvider); } - @Override - public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { + private @Nullable SchemaProvider schemaProviderFor(TypeDescriptor typeDescriptor) { TypeDescriptor type = typeDescriptor; do { SchemaProvider schemaProvider = providers.get(type); if (schemaProvider != null) { - return schemaProvider.schemaFor(type); + return schemaProvider; } Class superClass = type.getRawType().getSuperclass(); if (superClass == null || superClass.equals(Object.class)) { @@ -92,38 +91,24 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid } while (true); } + @Override + public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { + @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); + return schemaProvider != null ? schemaProvider.schemaFor(typeDescriptor) : null; + } + @Override public @Nullable SerializableFunction toRowFunction( TypeDescriptor typeDescriptor) { - TypeDescriptor type = typeDescriptor; - do { - SchemaProvider schemaProvider = providers.get(type); - if (schemaProvider != null) { - return (SerializableFunction) schemaProvider.toRowFunction(type); - } - Class superClass = type.getRawType().getSuperclass(); - if (superClass == null || superClass.equals(Object.class)) { - return null; - } - type = TypeDescriptor.of(superClass); - } while (true); + @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); + return schemaProvider != null ? schemaProvider.toRowFunction(typeDescriptor) : null; } @Override public @Nullable SerializableFunction fromRowFunction( TypeDescriptor typeDescriptor) { - TypeDescriptor type = typeDescriptor; - do { - SchemaProvider schemaProvider = providers.get(type); - if (schemaProvider != null) { - return (SerializableFunction) schemaProvider.fromRowFunction(type); - } - Class superClass = type.getRawType().getSuperclass(); - if (superClass == null || superClass.equals(Object.class)) { - return null; - } - type = TypeDescriptor.of(superClass); - } while (true); + @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); + return schemaProvider != null ? 
schemaProvider.fromRowFunction(typeDescriptor) : null; } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java index 54e2a595fa71..c3a71bbb454b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java @@ -160,7 +160,8 @@ public FunctionAndType(Type outputType, Function function) { public FunctionAndType(TypeDescriptor outputType, Function function) { this( - StaticSchemaInference.fieldFromType(outputType, new EmptyFieldValueTypeSupplier()), + StaticSchemaInference.fieldFromType( + outputType, new EmptyFieldValueTypeSupplier(), Collections.emptyMap()), function); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java index d7fddd8abfed..74e97bad4f0f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java @@ -53,6 +53,7 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; @@ -63,6 +64,7 @@ import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for managing AutoValue schemas. */ @@ -70,6 +72,7 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) +@Internal public class AutoValueUtils { public static TypeDescriptor getBaseAutoValueClass(TypeDescriptor typeDescriptor) { // AutoValue extensions may be nested @@ -161,7 +164,7 @@ private static boolean matchConstructor( // Verify that constructor parameters match (name and type) the inferred schema. 
for (Parameter parameter : constructor.getParameters()) { FieldValueTypeInformation type = typeMap.get(parameter.getName()); - if (type == null || type.getRawType() != parameter.getType()) { + if (type == null || !type.getRawType().equals(parameter.getType())) { valid = false; break; } @@ -178,7 +181,7 @@ private static boolean matchConstructor( } name = name.substring(0, name.length() - 1); FieldValueTypeInformation type = typeMap.get(name); - if (type == null || type.getRawType() != parameter.getType()) { + if (type == null || !type.getRawType().equals(parameter.getType())) { return false; } } @@ -196,11 +199,12 @@ private static boolean matchConstructor( return null; } - Map setterTypes = - ReflectUtils.getMethods(builderClass).stream() - .filter(ReflectUtils::isSetter) - .map(FieldValueTypeInformation::forSetter) - .collect(Collectors.toMap(FieldValueTypeInformation::getName, Function.identity())); + Map boundTypes = ReflectUtils.getAllBoundTypes(TypeDescriptor.of(builderClass)); + Map setterTypes = Maps.newHashMap(); + ReflectUtils.getMethods(builderClass).stream() + .filter(ReflectUtils::isSetter) + .map(m -> FieldValueTypeInformation.forSetter(m, boundTypes)) + .forEach(fv -> setterTypes.putIfAbsent(fv.getName(), fv)); List setterMethods = Lists.newArrayList(); // The builder methods to call in order. @@ -321,7 +325,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { Duplication.SINGLE, typeConversionsFactory .createSetterConversions(readParameter) - .convert(TypeDescriptor.of(parameter.getType())), + .convert(TypeDescriptor.of(parameter.getParameterizedType())), MethodInvocation.invoke(new ForLoadedMethod(setterMethod)), Removal.SINGLE); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java index c2b33c2d2315..65adc33a1bab 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java @@ -344,19 +344,22 @@ protected Type convertArray(TypeDescriptor type) { @Override protected Type convertCollection(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createCollectionType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertList(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createCollectionType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertIterable(TypeDescriptor type) { - TypeDescriptor ret = createIterableType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createIterableType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); return returnRawTypes ? 
ret.getRawType() : ret.getType(); } @@ -687,7 +690,8 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); @@ -707,7 +711,8 @@ protected StackManipulation convertIterable(TypeDescriptor type) { @Override protected StackManipulation convertCollection(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { @@ -726,7 +731,8 @@ protected StackManipulation convertCollection(TypeDescriptor type) { @Override protected StackManipulation convertList(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { @@ -745,8 +751,8 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { - final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0); - final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1); + final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0, Collections.emptyMap()); + final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1, Collections.emptyMap()); Type convertedKeyType = getFactory().createTypeConversion(true).convert(keyType); Type convertedValueType = getFactory().createTypeConversion(true).convert(valueType); @@ -1038,8 +1044,9 @@ protected StackManipulation convertIterable(TypeDescriptor type) { Type rowElementType = getFactory() .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor iterableElementType = ReflectUtils.getIterableComponentType(type); + .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + final TypeDescriptor iterableElementType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); if (!iterableElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = new ForLoadedType( @@ -1060,8 +1067,9 @@ protected StackManipulation convertCollection(TypeDescriptor type) { Type rowElementType = getFactory() .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + final TypeDescriptor collectionElementType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); if (!collectionElementType.hasUnresolvedParameters()) { ForLoadedType 
conversionFunction = @@ -1083,8 +1091,9 @@ protected StackManipulation convertList(TypeDescriptor type) { Type rowElementType = getFactory() .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + final TypeDescriptor collectionElementType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); StackManipulation readTrasformedValue = readValue; if (!collectionElementType.hasUnresolvedParameters()) { @@ -1113,11 +1122,17 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { Type rowKeyType = - getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 0)); - final TypeDescriptor keyElementType = ReflectUtils.getMapType(type, 0); + getFactory() + .createTypeConversion(false) + .convert(ReflectUtils.getMapType(type, 0, Collections.emptyMap())); + final TypeDescriptor keyElementType = + ReflectUtils.getMapType(type, 0, Collections.emptyMap()); Type rowValueType = - getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 1)); - final TypeDescriptor valueElementType = ReflectUtils.getMapType(type, 1); + getFactory() + .createTypeConversion(false) + .convert(ReflectUtils.getMapType(type, 1, Collections.emptyMap())); + final TypeDescriptor valueElementType = + ReflectUtils.getMapType(type, 1, Collections.emptyMap()); StackManipulation readTrasformedValue = readValue; if (!keyElementType.hasUnresolvedParameters() @@ -1475,7 +1490,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { Parameter parameter = parameters.get(i); ForLoadedType convertedType = new ForLoadedType( - (Class) convertType.convert(TypeDescriptor.of(parameter.getType()))); + (Class) convertType.convert(TypeDescriptor.of(parameter.getParameterizedType()))); // The instruction to read the parameter. Use the fieldMapping to reorder parameters as // necessary. 
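(Illustration, not part of the patch.) The ByteBuddyUtils call sites above pass Collections.emptyMap() for the new boundTypes argument because no generic bindings are known at those points; where bindings are known, the schema providers in this patch build the map with ReflectUtils.getAllBoundTypes (added to ReflectUtils.java further down) and substitute type variables via ReflectUtils.resolveType. A minimal sketch of how that resolution is expected to behave follows; the Box class and its value field are hypothetical, the map is left raw because the exact parameterized signature is not reproduced in this flattened diff, and the printed result is an expectation based on the added code rather than a verified run.

import java.lang.reflect.Type;
import java.util.Map;
import org.apache.beam.sdk.schemas.utils.ReflectUtils;
import org.apache.beam.sdk.values.TypeDescriptor;

public class BoundTypesSketch {
  // Hypothetical generic holder whose only field is typed by the type variable T.
  public static class Box<T> {
    public T value;
  }

  public static void main(String[] args) throws Exception {
    // Capture Box<String> so that the String type argument is preserved at runtime.
    TypeDescriptor<Box<String>> descriptor = new TypeDescriptor<Box<String>>() {};

    // Should map Box's type variable T to String, walking generic superclasses as well.
    Map boundTypes = ReflectUtils.getAllBoundTypes(descriptor);

    // The declared field type is the type variable T; resolving it should yield String.
    Type fieldType = Box.class.getField("value").getGenericType();
    Type resolved = ReflectUtils.resolveType(fieldType, boundTypes);
    System.out.println(resolved); // expected: class java.lang.String
  }
}

In the schema providers themselves this happens indirectly: JavaFieldSchema, JavaBeanSchema, and AutoValueSchema compute the bound-type map once per TypeDescriptor and thread it through FieldValueTypeInformation and StaticSchemaInference, which is what the parameterized-class tests added later in this patch exercise.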
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java index 7f2403035d97..e98a0b9495cf 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java @@ -22,6 +22,7 @@ import java.io.Serializable; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Type; +import java.util.Collections; import java.util.ServiceLoader; import net.bytebuddy.ByteBuddy; import net.bytebuddy.asm.AsmVisitorWrapper; @@ -36,6 +37,7 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.JavaFieldSchema.JavaFieldTypeSupplier; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; @@ -56,6 +58,7 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) +@Internal public class ConvertHelpers { private static class SchemaInformationProviders { private static final ServiceLoader INSTANCE = @@ -148,7 +151,8 @@ public static SerializableFunction getConvertPrimitive( TypeDescriptor outputTypeDescriptor, TypeConversionsFactory typeConversionsFactory) { FieldType expectedFieldType = - StaticSchemaInference.fieldFromType(outputTypeDescriptor, JavaFieldTypeSupplier.INSTANCE); + StaticSchemaInference.fieldFromType( + outputTypeDescriptor, JavaFieldTypeSupplier.INSTANCE, Collections.emptyMap()); if (!expectedFieldType.equals(fieldType)) { throw new IllegalArgumentException( "Element argument type " diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java index 911f79f6eeed..83f6b5c928d8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java @@ -22,6 +22,7 @@ import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.lang.reflect.Type; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -42,6 +43,7 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; @@ -61,11 +63,15 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) +@Internal public class JavaBeanUtils { /** Create a {@link Schema} for a Java Bean class. 
*/ public static Schema schemaFromJavaBeanClass( - TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { - return StaticSchemaInference.schemaFromClass(typeDescriptor, fieldValueTypeSupplier); + TypeDescriptor typeDescriptor, + FieldValueTypeSupplier fieldValueTypeSupplier, + Map boundTypes) { + return StaticSchemaInference.schemaFromClass( + typeDescriptor, fieldValueTypeSupplier, boundTypes); } private static final String CONSTRUCTOR_HELP_STRING = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java index 571b9c690900..1e60c9312cb3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java @@ -49,6 +49,7 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; @@ -70,11 +71,15 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) }) +@Internal public class POJOUtils { public static Schema schemaFromPojoClass( - TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { - return StaticSchemaInference.schemaFromClass(typeDescriptor, fieldValueTypeSupplier); + TypeDescriptor typeDescriptor, + FieldValueTypeSupplier fieldValueTypeSupplier, + Map boundTypes) { + return StaticSchemaInference.schemaFromClass( + typeDescriptor, fieldValueTypeSupplier, boundTypes); } // Static ByteBuddy instance used by all helpers. @@ -301,7 +306,7 @@ public static SchemaUserTypeCreator createStaticCreator( field.getDeclaringClass(), typeConversionsFactory .createTypeConversion(false) - .convert(TypeDescriptor.of(field.getType()))); + .convert(TypeDescriptor.of(field.getGenericType()))); builder = implementGetterMethods(builder, field, typeInformation.getName(), typeConversionsFactory); try { @@ -383,7 +388,7 @@ private static FieldValueSetter createSetter( field.getDeclaringClass(), typeConversionsFactory .createTypeConversion(false) - .convert(TypeDescriptor.of(field.getType()))); + .convert(TypeDescriptor.of(field.getGenericType()))); builder = implementSetterMethods(builder, field, typeConversionsFactory); try { return builder @@ -491,7 +496,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Do any conversions necessary. typeConversionsFactory .createSetterConversions(readField) - .convert(TypeDescriptor.of(field.getType())), + .convert(TypeDescriptor.of(field.getGenericType())), // Now update the field and return void. FieldAccess.forField(new ForLoadedField(field)).write(), MethodReturn.VOID); @@ -546,7 +551,8 @@ public ByteCodeAppender appender(final Target implementationTarget) { Field field = fields.get(i); ForLoadedType convertedType = - new ForLoadedType((Class) convertType.convert(TypeDescriptor.of(field.getType()))); + new ForLoadedType( + (Class) convertType.convert(TypeDescriptor.of(field.getGenericType()))); // The instruction to read the parameter. 
StackManipulation readParameter = @@ -563,7 +569,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Do any conversions necessary. typeConversionsFactory .createSetterConversions(readParameter) - .convert(TypeDescriptor.of(field.getType())), + .convert(TypeDescriptor.of(field.getGenericType())), // Now update the field. FieldAccess.forField(new ForLoadedField(field)).write()); stackManipulation = new StackManipulation.Compound(stackManipulation, updateField); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java index 4349a04c28ad..32cfa5689193 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java @@ -26,16 +26,17 @@ import java.lang.reflect.Modifier; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; import java.security.InvalidParameterException; import java.util.Arrays; import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.SchemaCreate; +import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; @@ -88,14 +89,23 @@ public static List getMethods(Class clazz) { return DECLARED_METHODS.computeIfAbsent( clazz, c -> { - return Arrays.stream(c.getDeclaredMethods()) - .filter( - m -> !m.isBridge()) // Covariant overloads insert bridge functions, which we must - // ignore. - .filter(m -> !Modifier.isPrivate(m.getModifiers())) - .filter(m -> !Modifier.isProtected(m.getModifiers())) - .filter(m -> !Modifier.isStatic(m.getModifiers())) - .collect(Collectors.toList()); + List methods = Lists.newArrayList(); + do { + if (c.getPackage() != null && c.getPackage().getName().startsWith("java.")) { + break; // skip java built-in classes + } + Arrays.stream(c.getDeclaredMethods()) + .filter( + m -> + !m.isBridge()) // Covariant overloads insert bridge functions, which we must + // ignore. + .filter(m -> !Modifier.isPrivate(m.getModifiers())) + .filter(m -> !Modifier.isProtected(m.getModifiers())) + .filter(m -> !Modifier.isStatic(m.getModifiers())) + .forEach(methods::add); + c = c.getSuperclass(); + } while (c != null); + return methods; }); } @@ -201,7 +211,8 @@ public static String stripSetterPrefix(String method) { } /** For an array T[] or a subclass of Iterable, return a TypeDescriptor describing T. 
*/ - public static @Nullable TypeDescriptor getIterableComponentType(TypeDescriptor valueType) { + public static @Nullable TypeDescriptor getIterableComponentType( + TypeDescriptor valueType, Map boundTypes) { TypeDescriptor componentType = null; if (valueType.isArray()) { Type component = valueType.getComponentType().getType(); @@ -215,7 +226,7 @@ public static String stripSetterPrefix(String method) { ParameterizedType ptype = (ParameterizedType) collection.getType(); java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); checkArgument(params.length == 1); - componentType = TypeDescriptor.of(params[0]); + componentType = TypeDescriptor.of(resolveType(params[0], boundTypes)); } else { throw new RuntimeException("Collection parameter is not parameterized!"); } @@ -223,14 +234,15 @@ public static String stripSetterPrefix(String method) { return componentType; } - public static TypeDescriptor getMapType(TypeDescriptor valueType, int index) { + public static TypeDescriptor getMapType( + TypeDescriptor valueType, int index, Map boundTypes) { TypeDescriptor mapType = null; if (valueType.isSubtypeOf(TypeDescriptor.of(Map.class))) { TypeDescriptor> map = valueType.getSupertype(Map.class); if (map.getType() instanceof ParameterizedType) { ParameterizedType ptype = (ParameterizedType) map.getType(); java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); - mapType = TypeDescriptor.of(params[index]); + mapType = TypeDescriptor.of(resolveType(params[index], boundTypes)); } else { throw new RuntimeException("Map type is not parameterized! " + map); } @@ -243,4 +255,49 @@ public static TypeDescriptor boxIfPrimitive(TypeDescriptor typeDescriptor) { ? TypeDescriptor.of(Primitives.wrap(typeDescriptor.getRawType())) : typeDescriptor; } + + /** + * If this (or a base class) is a parameterized type, return a map of all TypeVariable->Type + * bindings. This allows us to resolve types in any contained fields or methods. + */ + public static Map getAllBoundTypes(TypeDescriptor typeDescriptor) { + Map boundParameters = Maps.newHashMap(); + TypeDescriptor currentType = typeDescriptor; + do { + if (currentType.getType() instanceof ParameterizedType) { + ParameterizedType parameterizedType = (ParameterizedType) currentType.getType(); + TypeVariable[] typeVariables = currentType.getRawType().getTypeParameters(); + Type[] typeArguments = parameterizedType.getActualTypeArguments(); + if (typeArguments.length != typeVariables.length) { + throw new RuntimeException("Mismatched type argument lengths in type " + typeDescriptor); + } + for (int i = 0; i < typeVariables.length; ++i) { + boundParameters.put(typeVariables[i], typeArguments[i]); + } + } + Type superClass = currentType.getRawType().getGenericSuperclass(); + if (superClass == null || superClass.equals(Object.class)) { + break; + } + currentType = TypeDescriptor.of(superClass); + } while (true); + return boundParameters; + } + + public static Type resolveType(Type type, Map boundTypes) { + TypeDescriptor typeDescriptor = TypeDescriptor.of(type); + if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Iterable.class)) + || typeDescriptor.isSubtypeOf(TypeDescriptor.of(Map.class))) { + // Don't resolve these, as we special-case Map and Iterable.
+ return type; + } + + if (type instanceof TypeVariable) { + TypeVariable typeVariable = (TypeVariable) type; + return Preconditions.checkArgumentNotNull(boundTypes.get(typeVariable)); + } else { + return type; + } + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java index 196ee6f86593..275bc41be53d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java @@ -19,7 +19,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; import java.math.BigDecimal; import java.nio.ByteBuffer; import java.util.Arrays; @@ -29,10 +29,12 @@ import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.ReadableInstant; @@ -42,6 +44,7 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) +@Internal public class StaticSchemaInference { public static List sortBySchema( List types, Schema schema) { @@ -85,14 +88,17 @@ enum MethodType { * public getter methods, or special annotations on the class. */ public static Schema schemaFromClass( - TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { - return schemaFromClass(typeDescriptor, fieldValueTypeSupplier, new HashMap<>()); + TypeDescriptor typeDescriptor, + FieldValueTypeSupplier fieldValueTypeSupplier, + Map boundTypes) { + return schemaFromClass(typeDescriptor, fieldValueTypeSupplier, new HashMap<>(), boundTypes); } private static Schema schemaFromClass( TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier, - Map, Schema> alreadyVisitedSchemas) { + Map, Schema> alreadyVisitedSchemas, + Map boundTypes) { if (alreadyVisitedSchemas.containsKey(typeDescriptor)) { Schema existingSchema = alreadyVisitedSchemas.get(typeDescriptor); if (existingSchema == null) { @@ -106,7 +112,7 @@ private static Schema schemaFromClass( Schema.Builder builder = Schema.builder(); for (FieldValueTypeInformation type : fieldValueTypeSupplier.get(typeDescriptor)) { Schema.FieldType fieldType = - fieldFromType(type.getType(), fieldValueTypeSupplier, alreadyVisitedSchemas); + fieldFromType(type.getType(), fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes); Schema.Field f = type.isNullable() ? Schema.Field.nullable(type.getName(), fieldType) @@ -123,15 +129,18 @@ private static Schema schemaFromClass( /** Map a Java field type to a Beam Schema FieldType. 
*/ public static Schema.FieldType fieldFromType( - TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier) { - return fieldFromType(type, fieldValueTypeSupplier, new HashMap<>()); + TypeDescriptor type, + FieldValueTypeSupplier fieldValueTypeSupplier, + Map boundTypes) { + return fieldFromType(type, fieldValueTypeSupplier, new HashMap<>(), boundTypes); } // TODO(https://github.com/apache/beam/issues/21567): support type inference for logical types private static Schema.FieldType fieldFromType( TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier, - Map, Schema> alreadyVisitedSchemas) { + Map, Schema> alreadyVisitedSchemas, + Map boundTypes) { FieldType primitiveType = PRIMITIVE_TYPES.get(type.getRawType()); if (primitiveType != null) { return primitiveType; @@ -152,27 +161,25 @@ private static Schema.FieldType fieldFromType( } else { // Otherwise this is an array type. return FieldType.array( - fieldFromType(component, fieldValueTypeSupplier, alreadyVisitedSchemas)); + fieldFromType(component, fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes)); } } else if (type.isSubtypeOf(TypeDescriptor.of(Map.class))) { - TypeDescriptor> map = type.getSupertype(Map.class); - if (map.getType() instanceof ParameterizedType) { - ParameterizedType ptype = (ParameterizedType) map.getType(); - java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); - checkArgument(params.length == 2); - FieldType keyType = - fieldFromType( - TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas); - FieldType valueType = - fieldFromType( - TypeDescriptor.of(params[1]), fieldValueTypeSupplier, alreadyVisitedSchemas); - checkArgument( - keyType.getTypeName().isPrimitiveType(), - "Only primitive types can be map keys. type: " + keyType.getTypeName()); - return FieldType.map(keyType, valueType); - } else { - throw new RuntimeException("Cannot infer schema from unparameterized map."); - } + FieldType keyType = + fieldFromType( + ReflectUtils.getMapType(type, 0, boundTypes), + fieldValueTypeSupplier, + alreadyVisitedSchemas, + boundTypes); + FieldType valueType = + fieldFromType( + ReflectUtils.getMapType(type, 1, boundTypes), + fieldValueTypeSupplier, + alreadyVisitedSchemas, + boundTypes); + checkArgument( + keyType.getTypeName().isPrimitiveType(), + "Only primitive types can be map keys. type: " + keyType.getTypeName()); + return FieldType.map(keyType, valueType); } else if (type.isSubtypeOf(TypeDescriptor.of(CharSequence.class))) { return FieldType.STRING; } else if (type.isSubtypeOf(TypeDescriptor.of(ReadableInstant.class))) { @@ -180,26 +187,22 @@ private static Schema.FieldType fieldFromType( } else if (type.isSubtypeOf(TypeDescriptor.of(ByteBuffer.class))) { return FieldType.BYTES; } else if (type.isSubtypeOf(TypeDescriptor.of(Iterable.class))) { - TypeDescriptor> iterable = type.getSupertype(Iterable.class); - if (iterable.getType() instanceof ParameterizedType) { - ParameterizedType ptype = (ParameterizedType) iterable.getType(); - java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); - checkArgument(params.length == 1); - // TODO: should this be AbstractCollection? 
- if (type.isSubtypeOf(TypeDescriptor.of(Collection.class))) { - return FieldType.array( - fieldFromType( - TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas)); - } else { - return FieldType.iterable( - fieldFromType( - TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas)); - } + FieldType elementType = + fieldFromType( + Preconditions.checkArgumentNotNull( + ReflectUtils.getIterableComponentType(type, boundTypes)), + fieldValueTypeSupplier, + alreadyVisitedSchemas, + boundTypes); + // TODO: should this be AbstractCollection? + if (type.isSubtypeOf(TypeDescriptor.of(Collection.class))) { + return FieldType.array(elementType); } else { - throw new RuntimeException("Cannot infer schema from unparameterized collection."); + return FieldType.iterable(elementType); } } else { - return FieldType.row(schemaFromClass(type, fieldValueTypeSupplier, alreadyVisitedSchemas)); + return FieldType.row( + schemaFromClass(type, fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes)); } } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java index d0ee623dea7c..49fd2bfe2259 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java @@ -28,6 +28,7 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.Map; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; @@ -39,6 +40,7 @@ import org.apache.beam.sdk.schemas.utils.SchemaTestUtils; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.CaseFormat; import org.joda.time.DateTime; import org.joda.time.Instant; @@ -886,4 +888,151 @@ public void testSchema_SchemaFieldDescription() throws NoSuchSchemaException { assertEquals(FIELD_DESCRIPTION_SCHEMA.getField("lng"), schema.getField("lng")); assertEquals(FIELD_DESCRIPTION_SCHEMA.getField("str"), schema.getField("str")); } + + @AutoValue + @DefaultSchema(AutoValueSchema.class) + abstract static class ParameterizedAutoValue { + abstract W getValue1(); + + abstract T getValue2(); + + abstract V getValue3(); + + abstract X getValue4(); + } + + @Test + public void testAutoValueWithTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> typeDescriptor = + new TypeDescriptor>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_SCHEMA) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @DefaultSchema(AutoValueSchema.class) + abstract static class ParameterizedAutoValueSubclass + extends ParameterizedAutoValue { + abstract T getValue5(); + } + + @Test + public void testAutoValueWithInheritedTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> typeDescriptor = + new TypeDescriptor>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema 
expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_SCHEMA) + .addInt16Field("value5") + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @AutoValue + @DefaultSchema(AutoValueSchema.class) + abstract static class NestedParameterizedCollectionAutoValue { + abstract Iterable getNested(); + + abstract Map getMap(); + } + + @Test + public void testAutoValueWithNestedCollectionTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + NestedParameterizedCollectionAutoValue< + ParameterizedAutoValue, String>> + typeDescriptor = + new TypeDescriptor< + NestedParameterizedCollectionAutoValue< + ParameterizedAutoValue, String>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder() + .addIterableField("nested", FieldType.row(expectedInnerSchema)) + .addMapField("map", FieldType.STRING, FieldType.row(expectedInnerSchema)) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testAutoValueWithDoublyNestedCollectionTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + NestedParameterizedCollectionAutoValue< + Iterable>, String>> + typeDescriptor = + new TypeDescriptor< + NestedParameterizedCollectionAutoValue< + Iterable>, + String>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder() + .addIterableField("nested", FieldType.iterable(FieldType.row(expectedInnerSchema))) + .addMapField( + "map", FieldType.STRING, FieldType.iterable(FieldType.row(expectedInnerSchema))) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @AutoValue + @DefaultSchema(AutoValueSchema.class) + abstract static class NestedParameterizedAutoValue { + abstract T getNested(); + } + + @Test + public void testAutoValueWithNestedTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + NestedParameterizedAutoValue< + ParameterizedAutoValue>> + typeDescriptor = + new TypeDescriptor< + NestedParameterizedAutoValue< + ParameterizedAutoValue>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder().addRowField("nested", expectedInnerSchema).build(); + assertTrue(expectedSchema.equivalent(schema)); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java index 5313feb5c6c0..2252c3aef0db 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java @@ -68,6 +68,7 @@ import 
org.apache.beam.sdk.schemas.utils.TestJavaBeans.SimpleBeanWithAnnotations; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -625,4 +626,127 @@ public void testSetterConstructionWithRenamedFields() throws NoSuchSchemaExcepti assertEquals( registry.getFromRowFunction(BeanWithCaseFormat.class).apply(row), beanWithCaseFormat); } + + @Test + public void testBeanWithTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> + typeDescriptor = + new TypeDescriptor< + TestJavaBeans.SimpleParameterizedBean>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_BEAN_SCHEMA) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testBeanWithInheritedTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> typeDescriptor = + new TypeDescriptor>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_BEAN_SCHEMA) + .addInt16Field("value5") + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testBeanWithNestedCollectionTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + TestJavaBeans.NestedParameterizedCollectionBean< + TestJavaBeans.SimpleParameterizedBean, String>> + typeDescriptor = + new TypeDescriptor< + TestJavaBeans.NestedParameterizedCollectionBean< + TestJavaBeans.SimpleParameterizedBean, + String>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_BEAN_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder() + .addIterableField("nested", Schema.FieldType.row(expectedInnerSchema)) + .addMapField("map", Schema.FieldType.STRING, Schema.FieldType.row(expectedInnerSchema)) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testBeanWithDoublyNestedCollectionTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + TestJavaBeans.NestedParameterizedCollectionBean< + Iterable>, + String>> + typeDescriptor = + new TypeDescriptor< + TestJavaBeans.NestedParameterizedCollectionBean< + Iterable< + TestJavaBeans.SimpleParameterizedBean>, + String>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_BEAN_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder() + .addIterableField( + "nested", 
Schema.FieldType.iterable(Schema.FieldType.row(expectedInnerSchema))) + .addMapField( + "map", + Schema.FieldType.STRING, + Schema.FieldType.iterable(Schema.FieldType.row(expectedInnerSchema))) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testBeanWithNestedTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + TestJavaBeans.NestedParameterizedBean< + TestJavaBeans.SimpleParameterizedBean>> + typeDescriptor = + new TypeDescriptor< + TestJavaBeans.NestedParameterizedBean< + TestJavaBeans.SimpleParameterizedBean>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_BEAN_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder().addRowField("nested", expectedInnerSchema).build(); + assertTrue(expectedSchema.equivalent(schema)); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java index 11bef79b26f7..70bc3030924b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java @@ -76,6 +76,7 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -781,4 +782,123 @@ public void testCircularNestedPOJOThrows() throws NoSuchSchemaException { thrown.getMessage(), containsString("TestPOJOs$FirstCircularNestedPOJO")); } + + @Test + public void testPojoWithTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> + typeDescriptor = + new TypeDescriptor< + TestPOJOs.SimpleParameterizedPOJO>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_POJO_SCHEMA) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testPojoWithInheritedTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> typeDescriptor = + new TypeDescriptor>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_POJO_SCHEMA) + .addInt16Field("value5") + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testPojoWithNestedCollectionTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + TestPOJOs.NestedParameterizedCollectionPOJO< + TestPOJOs.SimpleParameterizedPOJO, String>> + typeDescriptor = + new TypeDescriptor< + TestPOJOs.NestedParameterizedCollectionPOJO< + 
TestPOJOs.SimpleParameterizedPOJO, + String>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_POJO_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder() + .addIterableField("nested", FieldType.row(expectedInnerSchema)) + .addMapField("map", FieldType.STRING, FieldType.row(expectedInnerSchema)) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testPojoWithDoublyNestedCollectionTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + TestPOJOs.NestedParameterizedCollectionPOJO< + Iterable>, + String>> + typeDescriptor = + new TypeDescriptor< + TestPOJOs.NestedParameterizedCollectionPOJO< + Iterable>, + String>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_POJO_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder() + .addIterableField("nested", FieldType.iterable(FieldType.row(expectedInnerSchema))) + .addMapField( + "map", FieldType.STRING, FieldType.iterable(FieldType.row(expectedInnerSchema))) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testPojoWithNestedTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + TestPOJOs.NestedParameterizedPOJO< + TestPOJOs.SimpleParameterizedPOJO>> + typeDescriptor = + new TypeDescriptor< + TestPOJOs.NestedParameterizedPOJO< + TestPOJOs.SimpleParameterizedPOJO>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_POJO_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder().addRowField("nested", expectedInnerSchema).build(); + assertTrue(expectedSchema.equivalent(schema)); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java index 021e39b84849..e0a45c2c82fe 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java @@ -34,6 +34,7 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.List; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; @@ -65,7 +66,9 @@ public class JavaBeanUtilsTest { public void testNullable() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); assertTrue(schema.getField("str").getType().getNullable()); assertFalse(schema.getField("anInt").getType().getNullable()); } @@ -74,7 +77,9 @@ public void testNullable() { public void testSimpleBean() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new 
TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(SIMPLE_BEAN_SCHEMA, schema); } @@ -82,7 +87,9 @@ public void testSimpleBean() { public void testNestedBean() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_BEAN_SCHEMA, schema); } @@ -90,7 +97,9 @@ public void testNestedBean() { public void testPrimitiveArray() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_ARRAY_BEAN_SCHEMA, schema); } @@ -98,7 +107,9 @@ public void testPrimitiveArray() { public void testNestedArray() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_ARRAY_BEAN_SCHEMA, schema); } @@ -106,7 +117,9 @@ public void testNestedArray() { public void testNestedCollection() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_COLLECTION_BEAN_SCHEMA, schema); } @@ -114,7 +127,9 @@ public void testNestedCollection() { public void testPrimitiveMap() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_MAP_BEAN_SCHEMA, schema); } @@ -122,7 +137,9 @@ public void testPrimitiveMap() { public void testNestedMap() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_MAP_BEAN_SCHEMA, schema); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java index 723353ed8d15..46c098dddaeb 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java @@ -35,6 +35,7 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.List; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; @@ -71,7 +72,9 @@ public class POJOUtilsTest { public void testNullables() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); assertTrue(schema.getField("str").getType().getNullable()); assertFalse(schema.getField("anInt").getType().getNullable()); } @@ -80,7 +83,9 @@ public void testNullables() { public void testSimplePOJO() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); 
+ new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); assertEquals(SIMPLE_POJO_SCHEMA, schema); } @@ -88,7 +93,9 @@ public void testSimplePOJO() { public void testNestedPOJO() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_POJO_SCHEMA, schema); } @@ -97,7 +104,8 @@ public void testNestedPOJOWithSimplePOJO() { Schema schema = POJOUtils.schemaFromPojoClass( new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE); + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_POJO_WITH_SIMPLE_POJO_SCHEMA, schema); } @@ -105,7 +113,9 @@ public void testNestedPOJOWithSimplePOJO() { public void testPrimitiveArray() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_ARRAY_POJO_SCHEMA, schema); } @@ -113,7 +123,9 @@ public void testPrimitiveArray() { public void testNestedArray() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_ARRAY_POJO_SCHEMA, schema); } @@ -121,7 +133,9 @@ public void testNestedArray() { public void testNestedCollection() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_COLLECTION_POJO_SCHEMA, schema); } @@ -129,7 +143,9 @@ public void testNestedCollection() { public void testPrimitiveMap() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_MAP_POJO_SCHEMA, schema); } @@ -137,7 +153,9 @@ public void testPrimitiveMap() { public void testNestedMap() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_MAP_POJO_SCHEMA, schema); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java index b5ad6f989d9e..cbc976144971 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java @@ -1397,4 +1397,95 @@ public void setValue(@Nullable Float value) { Schema.Field.nullable("value", FieldType.FLOAT) .withDescription("This value is the value stored in the object as a float.")) .build(); + + @DefaultSchema(JavaBeanSchema.class) + public static class SimpleParameterizedBean { + @Nullable private W value1; + @Nullable private T value2; + @Nullable private V value3; + @Nullable private X value4; + + public W getValue1() { + return value1; + } + + public void setValue1(W value1) { + this.value1 = value1; + } + + public T getValue2() 
{ + return value2; + } + + public void setValue2(T value2) { + this.value2 = value2; + } + + public V getValue3() { + return value3; + } + + public void setValue3(V value3) { + this.value3 = value3; + } + + public X getValue4() { + return value4; + } + + public void setValue4(X value4) { + this.value4 = value4; + } + } + + @DefaultSchema(JavaBeanSchema.class) + public static class SimpleParameterizedBeanSubclass + extends SimpleParameterizedBean { + @Nullable private T value5; + + public SimpleParameterizedBeanSubclass() {} + + public T getValue5() { + return value5; + } + + public void setValue5(T value5) { + this.value5 = value5; + } + } + + @DefaultSchema(JavaBeanSchema.class) + public static class NestedParameterizedCollectionBean { + private Iterable nested; + private Map map; + + public Iterable getNested() { + return nested; + } + + public Map getMap() { + return map; + } + + public void setNested(Iterable nested) { + this.nested = nested; + } + + public void setMap(Map map) { + this.map = map; + } + } + + @DefaultSchema(JavaBeanSchema.class) + public static class NestedParameterizedBean { + private T nested; + + public T getNested() { + return nested; + } + + public void setNested(T nested) { + this.nested = nested; + } + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java index 789de02adee8..ce7409365d09 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java @@ -495,6 +495,125 @@ public int hashCode() { .addStringField("stringBuilder") .build(); + @DefaultSchema(JavaFieldSchema.class) + public static class SimpleParameterizedPOJO { + public W value1; + public T value2; + public V value3; + public X value4; + + public SimpleParameterizedPOJO() {} + + public SimpleParameterizedPOJO(W value1, T value2, V value3, X value4) { + this.value1 = value1; + this.value2 = value2; + this.value3 = value3; + this.value4 = value4; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SimpleParameterizedPOJO)) { + return false; + } + SimpleParameterizedPOJO that = (SimpleParameterizedPOJO) o; + return Objects.equals(value1, that.value1) + && Objects.equals(value2, that.value2) + && Objects.equals(value3, that.value3) + && Objects.equals(value4, that.value4); + } + + @Override + public int hashCode() { + return Objects.hash(value1, value2, value3, value4); + } + } + + @DefaultSchema(JavaFieldSchema.class) + public static class SimpleParameterizedPOJOSubclass + extends SimpleParameterizedPOJO { + public T value5; + + public SimpleParameterizedPOJOSubclass() {} + + public SimpleParameterizedPOJOSubclass(T value5) { + this.value5 = value5; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SimpleParameterizedPOJOSubclass)) { + return false; + } + SimpleParameterizedPOJOSubclass that = (SimpleParameterizedPOJOSubclass) o; + return Objects.equals(value5, that.value5); + } + + @Override + public int hashCode() { + return Objects.hash(value4); + } + } + + @DefaultSchema(JavaFieldSchema.class) + public static class NestedParameterizedCollectionPOJO { + public Iterable nested; + public Map map; + + public NestedParameterizedCollectionPOJO(Iterable nested, Map map) { + this.nested = nested; + this.map = map; + } + + @Override + public boolean 
equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof NestedParameterizedCollectionPOJO)) { + return false; + } + NestedParameterizedCollectionPOJO that = (NestedParameterizedCollectionPOJO) o; + return Objects.equals(nested, that.nested) && Objects.equals(map, that.map); + } + + @Override + public int hashCode() { + return Objects.hash(nested, map); + } + } + + @DefaultSchema(JavaFieldSchema.class) + public static class NestedParameterizedPOJO { + public T nested; + + public NestedParameterizedPOJO(T nested) { + this.nested = nested; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof NestedParameterizedPOJO)) { + return false; + } + NestedParameterizedPOJO that = (NestedParameterizedPOJO) o; + return Objects.equals(nested, that.nested); + } + + @Override + public int hashCode() { + return Objects.hash(nested); + } + } /** A POJO containing a nested class. * */ @DefaultSchema(JavaFieldSchema.class) public static class NestedPOJO { @@ -887,7 +1006,7 @@ public boolean equals(@Nullable Object o) { if (this == o) { return true; } - if (!(o instanceof PojoWithNestedArray)) { + if (!(o instanceof PojoWithIterable)) { return false; } PojoWithIterable that = (PojoWithIterable) o; diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java index 0a82663c1771..1a530a3f6ca5 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java @@ -78,8 +78,8 @@ private static SchemaUserTypeCreator createCreator(Class clazz, Schema sc // Generate a method call to create and invoke the SpecificRecord's constructor. . 
MethodCall construct = MethodCall.construct(baseConstructor); - for (int i = 0; i < baseConstructor.getParameterTypes().length; ++i) { - Class baseType = baseConstructor.getParameterTypes()[i]; + for (int i = 0; i < baseConstructor.getGenericParameterTypes().length; ++i) { + Type baseType = baseConstructor.getGenericParameterTypes()[i]; construct = construct.with(readAndConvertParameter(baseType, i), baseType); } @@ -110,7 +110,7 @@ private static SchemaUserTypeCreator createCreator(Class clazz, Schema sc } private static StackManipulation readAndConvertParameter( - Class constructorParameterType, int index) { + Type constructorParameterType, int index) { TypeConversionsFactory typeConversionsFactory = new AvroUtils.AvroTypeConversionFactory(); // The types in the AVRO-generated constructor might be the types returned by Beam's Row class, diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java index 1b1c45969307..1324d254e44e 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java @@ -814,6 +814,9 @@ public List get(TypeDescriptor typeDescriptor) { @Override public List get(TypeDescriptor typeDescriptor, Schema schema) { + Map boundTypes = + ReflectUtils.getAllBoundTypes(typeDescriptor); + Map mapping = getMapping(schema); List methods = ReflectUtils.getMethods(typeDescriptor.getRawType()); List types = Lists.newArrayList(); @@ -821,7 +824,7 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = methods.get(i); if (ReflectUtils.isGetter(method)) { FieldValueTypeInformation fieldValueTypeInformation = - FieldValueTypeInformation.forGetter(method, i); + FieldValueTypeInformation.forGetter(method, i, boundTypes); String name = mapping.get(fieldValueTypeInformation.getName()); if (name != null) { types.add(fieldValueTypeInformation.withName(name)); @@ -865,13 +868,16 @@ private Map getMapping(Schema schema) { private static final class AvroPojoFieldValueTypeSupplier implements FieldValueTypeSupplier { @Override public List get(TypeDescriptor typeDescriptor) { + Map boundTypes = + ReflectUtils.getAllBoundTypes(typeDescriptor); List classFields = ReflectUtils.getFields(typeDescriptor.getRawType()); Map types = Maps.newHashMap(); for (int i = 0; i < classFields.size(); ++i) { java.lang.reflect.Field f = classFields.get(i); if (!f.isAnnotationPresent(AvroIgnore.class)) { - FieldValueTypeInformation typeInformation = FieldValueTypeInformation.forField(f, i); + FieldValueTypeInformation typeInformation = + FieldValueTypeInformation.forField(f, i, boundTypes); AvroName avroname = f.getAnnotation(AvroName.class); if (avroname != null) { typeInformation = typeInformation.withName(avroname.value()); diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java index d159e9de44a8..fcfc40403b43 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java @@ -39,6 +39,7 @@ import java.lang.reflect.Modifier; import 
java.lang.reflect.Type; import java.util.Arrays; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -1045,7 +1046,8 @@ FieldValueSetter getProtoFieldValueSetter( } else { Method method = getProtoSetter(methods, field.getName(), field.getType()); return JavaBeanUtils.createSetter( - FieldValueTypeInformation.forSetter(method, protoSetterPrefix(field.getType())), + FieldValueTypeInformation.forSetter( + method, protoSetterPrefix(field.getType()), Collections.emptyMap()), new ProtoTypeConversionsFactory()); } } diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java index faf3ad407af5..4b8d51abdea6 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java @@ -23,6 +23,7 @@ import com.google.protobuf.DynamicMessage; import com.google.protobuf.Message; import java.lang.reflect.Method; +import java.util.Collections; import java.util.List; import java.util.Map; import org.apache.beam.sdk.extensions.protobuf.ProtoByteBuddyUtils.ProtoTypeConversionsFactory; @@ -72,7 +73,8 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = getProtoGetter(methods, oneOfField.getName(), oneOfField.getType()); oneOfTypes.put( oneOfField.getName(), - FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + FieldValueTypeInformation.forGetter(method, i, Collections.emptyMap()) + .withName(field.getName())); } // Add an entry that encapsulates information about all possible getters. types.add( @@ -82,7 +84,9 @@ public List get(TypeDescriptor typeDescriptor, Sch } else { // This is a simple field. Add the getter. 
Method method = getProtoGetter(methods, field.getName(), field.getType()); - types.add(FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + types.add( + FieldValueTypeInformation.forGetter(method, i, Collections.emptyMap()) + .withName(field.getName())); } } return types; diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java index d5f1745a9a2c..64f600903d87 100644 --- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java +++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java @@ -25,6 +25,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -389,7 +390,8 @@ private Schema generateSchemaDirectly( fieldName, StaticSchemaInference.fieldFromType( TypeDescriptor.of(field.getClass()), - JavaFieldSchema.JavaFieldTypeSupplier.INSTANCE)); + JavaFieldSchema.JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap())); } counter++; diff --git a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java index 5f4e195f227f..73b3709da832 100644 --- a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java +++ b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java @@ -242,10 +242,11 @@ private FieldValueTypeInformation fieldValueTypeInfo(Class type, String field if (factoryMethods.size() > 1) { throw new IllegalStateException("Overloaded factory methods: " + factoryMethods); } - return FieldValueTypeInformation.forSetter(factoryMethods.get(0), ""); + return FieldValueTypeInformation.forSetter(factoryMethods.get(0), "", Collections.emptyMap()); } else { try { - return FieldValueTypeInformation.forField(type.getDeclaredField(fieldName), 0); + return FieldValueTypeInformation.forField( + type.getDeclaredField(fieldName), 0, Collections.emptyMap()); } catch (NoSuchFieldException e) { throw new IllegalArgumentException(e); } From 7d0bfd0d6ef453fea4672d3c7752cc02e22de351 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 22:38:59 -0700 Subject: [PATCH 36/82] Bump google.golang.org/protobuf from 1.34.2 to 1.35.1 in /sdks (#32799) Bumps google.golang.org/protobuf from 1.34.2 to 1.35.1. --- updated-dependencies: - dependency-name: google.golang.org/protobuf dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- sdks/go.mod | 2 +- sdks/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index 0b5ac98df404..9aa839d67e1a 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -61,7 +61,7 @@ require ( google.golang.org/api v0.199.0 google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 google.golang.org/grpc v1.67.1 - google.golang.org/protobuf v1.34.2 + google.golang.org/protobuf v1.35.1 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/sdks/go.sum b/sdks/go.sum index db6d71b061b5..51e4d58c237d 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -1913,8 +1913,8 @@ google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.29.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= +google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= From 6b3a1b29f264b829745309b0b43639bd3625da76 Mon Sep 17 00:00:00 2001 From: claudevdm <33973061+claudevdm@users.noreply.github.com> Date: Wed, 16 Oct 2024 09:06:02 -0400 Subject: [PATCH 37/82] Bigquery fixes (#32780) * Bigquery fixes * Remove unnecessary comprehension loop --------- Co-authored-by: Claude --- .../enrichment_handlers/bigquery.py | 19 +- .../enrichment_handlers/bigquery_it_test.py | 300 +++++++++++------- 2 files changed, 199 insertions(+), 120 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index 382ae123a81d..ea98fb6b0bbd 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -171,6 +171,14 @@ def _execute_query(self, query: str): except RuntimeError as e: raise RuntimeError(f"Could not complete the query request: {query}. 
{e}") + def create_row_key(self, row: beam.Row): + if self.condition_value_fn: + return tuple(self.condition_value_fn(row)) + if self.fields: + row_dict = row._asdict() + return (tuple(row_dict[field] for field in self.fields)) + raise ValueError("Either fields or condition_value_fn must be specified") + def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): if isinstance(request, List): values = [] @@ -180,7 +188,7 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): raw_query = self.query_template if batch_size > 1: batched_condition_template = ' or '.join( - [self.row_restriction_template] * batch_size) + [fr'({self.row_restriction_template})'] * batch_size) raw_query = self.query_template.replace( self.row_restriction_template, batched_condition_template) for req in request: @@ -194,14 +202,15 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): "Make sure the values passed in `fields` are the " "keys in the input `beam.Row`." + str(e)) values.extend(current_values) - requests_map.update((val, req) for val in current_values) + requests_map[self.create_row_key(req)] = req query = raw_query.format(*values) responses_dict = self._execute_query(query) for response in responses_dict: - for value in response.values(): - if value in requests_map: - responses.append((requests_map[value], beam.Row(**response))) + response_row = beam.Row(**response) + response_key = self.create_row_key(response_row) + if response_key in requests_map: + responses.append((requests_map[response_key], response_row)) return responses else: request_dict = request._asdict() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py index 0b8a384b934d..dd99e386555e 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py @@ -14,7 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
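For context on the handler change just above: each request (and later each returned row) is keyed by a tuple built from `fields` or `condition_value_fn`, and batched lookups simply OR together one parenthesized copy of `row_restriction_template` per request. A minimal standalone sketch of that idea follows — illustrative only, with made-up table and field names, not the handler code itself:

```python
# Standalone illustration of the batching/keying approach used in the
# BigQuery enrichment handler above (names mirror the handler's arguments,
# but this is a sketch, not the handler).
from typing import Dict, List, Tuple

row_restriction_template = "id = {} AND distribution_center_id = {}"
fields = ["id", "distribution_center_id"]

def create_row_key(row: Dict[str, object]) -> Tuple:
  # Key each request (and later each response row) by its restriction values.
  return tuple(row[f] for f in fields)

def build_batched_query(requests: List[Dict[str, object]], table: str) -> str:
  # One '(...)' clause per request, OR-ed together, like the
  # batched_condition_template constructed in the handler.
  condition = " or ".join(f"({row_restriction_template})" for _ in requests)
  values = [row[f] for row in requests for f in fields]
  return f"SELECT * FROM `{table}` WHERE " + condition.format(*values)

requests = [{"id": 1, "distribution_center_id": 3},
            {"id": 2, "distribution_center_id": 1}]
print(build_batched_query(requests, "project.dataset.product_details"))
# Responses are then paired back to requests via create_row_key(response_row).
```

Keying responses with the same tuple as requests is what lets the batched handler pair every joined row back to the element that produced it, which is exactly what the `requests_map` change accomplishes.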
# +import functools import logging +import secrets +import time import unittest from unittest.mock import MagicMock @@ -22,7 +25,11 @@ import apache_beam as beam from apache_beam.coders import coders +from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper +from apache_beam.io.gcp.internal.clients import bigquery from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to # pylint: disable=ungrouped-imports try: @@ -31,8 +38,7 @@ from apache_beam.transforms.enrichment import Enrichment from apache_beam.transforms.enrichment_handlers.bigquery import \ BigQueryEnrichmentHandler - from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store_it_test import \ - ValidateResponse + from apitools.base.py.exceptions import HttpError except ImportError: raise unittest.SkipTest( 'Google Cloud BigQuery dependencies are not installed.') @@ -40,24 +46,101 @@ _LOGGER = logging.getLogger(__name__) -def query_fn(row: beam.Row): - query = ( - "SELECT * FROM " - "`apache-beam-testing.my_ecommerce.product_details`" - " WHERE id = '{}'".format(row.id)) # type: ignore[attr-defined] - return query - - def condition_value_fn(row: beam.Row): return [row.id] # type: ignore[attr-defined] +def query_fn(table, row: beam.Row): + return f"SELECT * FROM `{table}` WHERE id = {row.id}" # type: ignore[attr-defined] + + +@pytest.mark.uses_testcontainer +class BigQueryEnrichmentIT(unittest.TestCase): + bigquery_dataset_id = 'python_enrichment_transform_read_table_' + project = "apache-beam-testing" + + @classmethod + def setUpClass(cls): + cls.bigquery_client = BigQueryWrapper() + cls.dataset_id = '%s%d%s' % ( + cls.bigquery_dataset_id, int(time.time()), secrets.token_hex(3)) + cls.bigquery_client.get_or_create_dataset(cls.project, cls.dataset_id) + _LOGGER.info( + "Created dataset %s in project %s", cls.dataset_id, cls.project) + + @classmethod + def tearDownClass(cls): + request = bigquery.BigqueryDatasetsDeleteRequest( + projectId=cls.project, datasetId=cls.dataset_id, deleteContents=True) + try: + _LOGGER.debug( + "Deleting dataset %s in project %s", cls.dataset_id, cls.project) + cls.bigquery_client.client.datasets.Delete(request) + except HttpError: + _LOGGER.warning( + 'Failed to clean up dataset %s in project %s', + cls.dataset_id, + cls.project) + + @pytest.mark.uses_testcontainer -class TestBigQueryEnrichmentIT(unittest.TestCase): +class TestBigQueryEnrichmentIT(BigQueryEnrichmentIT): + table_data = [ + { + "id": 1, "name": "A", 'quantity': 2, 'distribution_center_id': 3 + }, + { + "id": 2, "name": "B", 'quantity': 3, 'distribution_center_id': 1 + }, + { + "id": 3, "name": "C", 'quantity': 10, 'distribution_center_id': 4 + }, + { + "id": 4, "name": "D", 'quantity': 1, 'distribution_center_id': 3 + }, + { + "id": 5, "name": "C", 'quantity': 100, 'distribution_center_id': 4 + }, + { + "id": 6, "name": "D", 'quantity': 11, 'distribution_center_id': 3 + }, + { + "id": 7, "name": "C", 'quantity': 7, 'distribution_center_id': 1 + }, + { + "id": 8, "name": "D", 'quantity': 4, 'distribution_center_id': 1 + }, + ] + + @classmethod + def create_table(cls, table_name): + fields = [('id', 'INTEGER'), ('name', 'STRING'), ('quantity', 'INTEGER'), + ('distribution_center_id', 'INTEGER')] + table_schema = bigquery.TableSchema() + for name, field_type in fields: + table_field = bigquery.TableFieldSchema() + table_field.name = name + table_field.type = field_type + table_schema.fields.append(table_field) + table = 
bigquery.Table( + tableReference=bigquery.TableReference( + projectId=cls.project, datasetId=cls.dataset_id, + tableId=table_name), + schema=table_schema) + request = bigquery.BigqueryTablesInsertRequest( + projectId=cls.project, datasetId=cls.dataset_id, table=table) + cls.bigquery_client.client.tables.Insert(request) + cls.bigquery_client.insert_rows( + cls.project, cls.dataset_id, table_name, cls.table_data) + cls.table_name = f"{cls.project}.{cls.dataset_id}.{table_name}" + + @classmethod + def setUpClass(cls): + super(TestBigQueryEnrichmentIT, cls).setUpClass() + cls.create_table('product_details') + def setUp(self) -> None: - self.project = 'apache-beam-testing' - self.condition_template = "id = '{}'" - self.table_name = "`apache-beam-testing.my_ecommerce.product_details`" + self.condition_template = "id = {}" self.retries = 3 self._start_container() @@ -82,123 +165,119 @@ def tearDown(self) -> None: self.client = None def test_bigquery_enrichment(self): - expected_fields = [ - 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] fields = ['id'] requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), ] handler = BigQueryEnrichmentHandler( project=self.project, - row_restriction_template=self.condition_template, + row_restriction_template="id = {}", table_name=self.table_name, fields=fields, - min_batch_size=2, + min_batch_size=1, max_batch_size=100, ) + with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( - test_pipeline - | beam.Create(requests) - | Enrichment(handler) - | beam.ParDo(ValidateResponse(expected_fields))) + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) - def test_bigquery_enrichment_with_query_fn(self): - expected_fields = [ - 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + assert_that(pcoll, equal_to(expected_rows)) + + def test_bigquery_enrichment_batched(self): + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] + fields = ['id'] requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), ] - handler = BigQueryEnrichmentHandler(project=self.project, query_fn=query_fn) + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template="id = {}", + table_name=self.table_name, + fields=fields, + min_batch_size=2, + max_batch_size=100, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( - test_pipeline - | beam.Create(requests) - | Enrichment(handler) - | beam.ParDo(ValidateResponse(expected_fields))) + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) - def test_bigquery_enrichment_with_condition_value_fn(self): - expected_fields = [ - 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + assert_that(pcoll, equal_to(expected_rows)) + + def test_bigquery_enrichment_batched_multiple_fields(self): + expected_rows = [ + beam.Row(id=1, 
distribution_center_id=3, name="A", quantity=2), + beam.Row(id=2, distribution_center_id=1, name="B", quantity=3) ] + fields = ['id', 'distribution_center_id'] requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, distribution_center_id=3), + beam.Row(id=2, distribution_center_id=1), ] handler = BigQueryEnrichmentHandler( project=self.project, - row_restriction_template=self.condition_template, + row_restriction_template="id = {} AND distribution_center_id = {}", table_name=self.table_name, - condition_value_fn=condition_value_fn, - min_batch_size=2, + fields=fields, + min_batch_size=8, max_batch_size=100, ) + with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( - test_pipeline - | beam.Create(requests) - | Enrichment(handler) - | beam.ParDo(ValidateResponse(expected_fields))) + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) - def test_bigquery_enrichment_with_condition_without_batch(self): - expected_fields = [ - 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + assert_that(pcoll, equal_to(expected_rows)) + + def test_bigquery_enrichment_with_query_fn(self): + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), + ] + fn = functools.partial(query_fn, self.table_name) + handler = BigQueryEnrichmentHandler(project=self.project, query_fn=fn) + with TestPipeline(is_integration_test=True) as test_pipeline: + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) + + assert_that(pcoll, equal_to(expected_rows)) + + def test_bigquery_enrichment_with_condition_value_fn(self): + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) + ] + requests = [ + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), ] handler = BigQueryEnrichmentHandler( project=self.project, row_restriction_template=self.condition_template, table_name=self.table_name, condition_value_fn=condition_value_fn, + min_batch_size=2, + max_batch_size=100, ) with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( - test_pipeline - | beam.Create(requests) - | Enrichment(handler) - | beam.ParDo(ValidateResponse(expected_fields))) + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) + + assert_that(pcoll, equal_to(expected_rows)) def test_bigquery_enrichment_bad_request(self): requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), ] handler = BigQueryEnrichmentHandler( project=self.project, @@ -231,18 +310,13 @@ def test_bigquery_enrichment_with_redis(self): requests. Since all requests are cached, it will return from there without making calls to the BigQuery service. 
""" - expected_fields = [ - 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' - ] requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), + ] + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] handler = BigQueryEnrichmentHandler( project=self.project, @@ -253,11 +327,12 @@ def test_bigquery_enrichment_with_redis(self): max_batch_size=100, ) with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( + pcoll_populate_cache = ( test_pipeline | beam.Create(requests) - | Enrichment(handler).with_redis_cache(self.host, self.port) - | beam.ParDo(ValidateResponse(expected_fields))) + | Enrichment(handler).with_redis_cache(self.host, self.port)) + + assert_that(pcoll_populate_cache, equal_to(expected_rows)) # manually check cache entry c = coders.StrUtf8Coder() @@ -268,20 +343,15 @@ def test_bigquery_enrichment_with_redis(self): raise ValueError("No cache entry found for %s" % key) actual = BigQueryEnrichmentHandler.__call__ - BigQueryEnrichmentHandler.__call__ = MagicMock( - return_value=( - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), - beam.Row())) + BigQueryEnrichmentHandler.__call__ = MagicMock(return_value=(beam.Row())) with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( + pcoll_cached = ( test_pipeline | beam.Create(requests) - | Enrichment(handler).with_redis_cache(self.host, self.port) - | beam.ParDo(ValidateResponse(expected_fields))) + | Enrichment(handler).with_redis_cache(self.host, self.port)) + + assert_that(pcoll_cached, equal_to(expected_rows)) BigQueryEnrichmentHandler.__call__ = actual From 3e49714c9dc9ee0902bfae05719ca94199ce509a Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Wed, 16 Oct 2024 10:06:30 -0400 Subject: [PATCH 38/82] Change to rez --- website/www/site/content/en/blog/beam-summit-2024-overview.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/website/www/site/content/en/blog/beam-summit-2024-overview.md b/website/www/site/content/en/blog/beam-summit-2024-overview.md index 8d5eb209bea2..5cf922d69544 100644 --- a/website/www/site/content/en/blog/beam-summit-2024-overview.md +++ b/website/www/site/content/en/blog/beam-summit-2024-overview.md @@ -8,7 +8,7 @@ aliases: authors: - liferoad - damccorm - - rezarokni + - rez --- + +We are happy to present the new 2.60.0 release of Beam. +This release includes both improvements and new functionality. +See the [download page](/get-started/downloads/#2600-2024-10-17) for this release. + + + +For more information on changes in 2.60.0, check out the [detailed release notes](https://github.com/apache/beam/milestone/24). 
+ +## Highlights + +* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528)) +* [Managed Iceberg] Added support for streaming writes ([#32451](https://github.com/apache/beam/pull/32451)) +* [Managed Iceberg] Added auto-sharding for streaming writes ([#32612](https://github.com/apache/beam/pull/32612)) +* [Managed Iceberg] Added support for writing to dynamic destinations ([#32565](https://github.com/apache/beam/pull/32565)) + +## New Features / Improvements + +* Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)). +* Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349)) +* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528)) +* Prism release binaries and container bootloaders are now being built with the latest Go 1.23 patch. ([#32575](https://github.com/apache/beam/pull/32575)) +* Prism + * Prism now supports Bundle Finalization. ([#32425](https://github.com/apache/beam/pull/32425)) +* Significantly improved performance of Kafka IO reads that enable [commitOffsetsInFinalize](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/io/kafka/KafkaIO.Read.html#commitOffsetsInFinalize--) by removing the data reshuffle from SDF implementation. ([#31682](https://github.com/apache/beam/pull/31682)). +* Added support for dynamic writing in MqttIO (Java) ([#19376](https://github.com/apache/beam/issues/19376)) +* Optimized Spark Runner parDo transform evaluator (Java) ([#32537](https://github.com/apache/beam/issues/32537)) +* [Managed Iceberg] More efficient manifest file writes/commits ([#32666](https://github.com/apache/beam/issues/32666)) + +## Breaking Changes + +* In Python, assert_that now throws if it is not in a pipeline context instead of silently succeeding ([#30771](https://github.com/apache/beam/pull/30771)) +* In Python and YAML, ReadFromJson now override the dtype from None to + an explicit False. Most notably, string values like `"123"` are preserved + as strings rather than silently coerced (and possibly truncated) to numeric + values. To retain the old behavior, pass `dtype=True` (or any other value + accepted by `pandas.read_json`). +* Users of KafkaIO Read transform that enable [commitOffsetsInFinalize](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/io/kafka/KafkaIO.Read.html#commitOffsetsInFinalize--) might encounter pipeline graph compatibility issues when updating the pipeline. To mitigate, set the `updateCompatibilityVersion` option to the SDK version used for the original pipeline, example `--updateCompatabilityVersion=2.58.1` + +## Deprecations + +* Python 3.8 is reaching EOL and support is being removed in Beam 2.61.0. The 2.60.0 release will warn users +when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192)) + +## Bugfixes + +* (Java) Fixed custom delimiter issues in TextIO ([#32249](https://github.com/apache/beam/issues/32249), [#32251](https://github.com/apache/beam/issues/32251)). +* (Java, Python, Go) Fixed PeriodicSequence backlog bytes reporting, which was preventing Dataflow Runner autoscaling from functioning properly ([#32506](https://github.com/apache/beam/issues/32506)). 
+* (Java) Fix improper decoding of rows with schemas containing nullable fields when encoded with a schema with equal encoding positions but modified field order. ([#32388](https://github.com/apache/beam/issues/32388)). + +## Known Issues + +N/A + +For the most up to date list of known issues, see https://github.com/apache/beam/blob/master/CHANGES.md + +## List of Contributors + +According to git shortlog, the following people contributed to the 2.60.0 release. Thank you to all contributors! + +Ahmed Abualsaud, Aiden Grossman, Arun Pandian, Bartosz Zablocki, Chamikara Jayalath, Claire McGinty, DKPHUONG, Damon Douglass, Danny McCormick, Dip Patel, Ferran Fernández Garrido, Hai Joey Tran, Hyeonho Kim, Igor Bernstein, Israel Herraiz, Jack McCluskey, Jaehyeon Kim, Jeff Kinard, Jeffrey Kinard, Joey Tran, Kenneth Knowles, Kirill Berezin, Michel Davit, Minbo Bae, Naireen Hussain, Niel Markwick, Nito Buendia, Reeba Qureshi, Reuven Lax, Robert Bradshaw, Robert Burke, Rohit Sinha, Ryan Fu, Sam Whittle, Shunping Huang, Svetak Sundhar, Udaya Chathuranga, Vitaly Terentyev, Vlado Djerek, Yi Hu, Claude van der Merwe, XQ Hu, Martin Trieu, Valentyn Tymofieiev, twosom diff --git a/website/www/site/content/en/get-started/downloads.md b/website/www/site/content/en/get-started/downloads.md index 08614b8835c1..ff432996578d 100644 --- a/website/www/site/content/en/get-started/downloads.md +++ b/website/www/site/content/en/get-started/downloads.md @@ -96,10 +96,18 @@ versions denoted `0.x.y`. ## Releases +### 2.60.0 (2024-10-17) + +Official [source code download](https://downloads.apache.org/beam/2.60.0/apache-beam-2.60.0-source-release.zip). +[SHA-512](https://downloads.apache.org/beam/2.60.0/apache-beam-2.60.0-source-release.zip.sha512). +[signature](https://downloads.apache.org/beam/2.60.0/apache-beam-2.60.0-source-release.zip.asc). + +[Release notes](https://github.com/apache/beam/releases/tag/v2.60.0) + ### 2.59.0 (2024-09-11) -Official [source code download](https://downloads.apache.org/beam/2.59.0/apache-beam-2.59.0-source-release.zip). -[SHA-512](https://downloads.apache.org/beam/2.59.0/apache-beam-2.59.0-source-release.zip.sha512). -[signature](https://downloads.apache.org/beam/2.59.0/apache-beam-2.59.0-source-release.zip.asc). +Official [source code download](https://archive.apache.org/dist/beam/2.59.0/apache-beam-2.59.0-source-release.zip). +[SHA-512](https://archive.apache.org/dist/beam/2.59.0/apache-beam-2.59.0-source-release.zip.sha512). +[signature](https://archive.apache.org/dist/beam/2.59.0/apache-beam-2.59.0-source-release.zip.asc). 
[Release notes](https://github.com/apache/beam/releases/tag/v2.59.0) From 6e3516baf2894b806e9cd3592257ee896c03fe15 Mon Sep 17 00:00:00 2001 From: liferoad Date: Thu, 17 Oct 2024 21:33:33 -0400 Subject: [PATCH 64/82] Revert "Update pyproject.toml by using grpcio-tools==1.65.5" --- sdks/python/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index a99599a2ce2b..4eb827297019 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -21,7 +21,7 @@ requires = [ "setuptools", "wheel>=0.36.0", - "grpcio-tools==1.65.5", + "grpcio-tools==1.62.1", "mypy-protobuf==3.5.0", # Avoid https://github.com/pypa/virtualenv/issues/2006 "distlib==0.3.7", From b412928be498065feae40bc8d14b79d9bcda6f30 Mon Sep 17 00:00:00 2001 From: Robert Burke Date: Fri, 18 Oct 2024 06:36:10 -0700 Subject: [PATCH 65/82] [#32847][prism] Add Github Action for Prism as a Python precommit (#32845) * Add Github Action for Prism as a Python precommit * Update the execution condition. --------- Co-authored-by: lostluck <13907733+lostluck@users.noreply.github.com> --- .../workflows/beam_PreCommit_Prism_Python.yml | 109 ++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 .github/workflows/beam_PreCommit_Prism_Python.yml diff --git a/.github/workflows/beam_PreCommit_Prism_Python.yml b/.github/workflows/beam_PreCommit_Prism_Python.yml new file mode 100644 index 000000000000..5eb26d139ef5 --- /dev/null +++ b/.github/workflows/beam_PreCommit_Prism_Python.yml @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: PreCommit Prism Python + +on: + push: + tags: ['v*'] + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/python/**' + - 'release/**' + - '.github/workflows/beam_PreCommit_Prism_Python.yml' + pull_request_target: + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/python/**' + - 'release/**' + - 'release/trigger_all_tests.json' + - '.github/trigger_files/beam_PreCommit_Prism_Python.json' + issue_comment: + types: [created] + schedule: + - cron: '30 2/6 * * *' + workflow_dispatch: + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.event.pull_request.head.label || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PreCommit_Prism_Python: + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) + timeout-minutes: 120 + runs-on: ['self-hosted', ubuntu-20.04, main] + strategy: + fail-fast: false + matrix: + job_name: ['beam_PreCommit_Prism_Python'] + job_phrase: ['Run Prism_Python PreCommit'] + python_version: ['3.9', '3.12'] + if: | + github.event_name == 'push' || + github.event_name == 'pull_request_target' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + startsWith(github.event.comment.body, 'Run Prism_Python PreCommit') + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} ${{ matrix.python_version }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + java-version: default + python-version: | + ${{ matrix.python_version }} + 3.9 + - name: Set PY_VER_CLEAN + id: set_py_ver_clean + run: | + PY_VER=${{ matrix.python_version }} + PY_VER_CLEAN=${PY_VER//.} + echo "py_ver_clean=$PY_VER_CLEAN" >> $GITHUB_OUTPUT + - name: run Prism Python Validates Runner script + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:python:test-suites:portable:py${{steps.set_py_ver_clean.outputs.py_ver_clean}}:prismValidatesRunner \ No newline at end of file From 4a4da907003ec2eaca2824179fc79a46e28d576a Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Fri, 18 Oct 2024 10:18:58 -0400 Subject: [PATCH 66/82] Follow up website and change.md after 2.60 release (#32853) * Update release date in CHANGE.md * Update latest version for website --- CHANGES.md | 2 +- website/www/site/config.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 
766f74fc3be0..6e589c318dd9 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -97,7 +97,7 @@ * ([#X](https://github.com/apache/beam/issues/X)). -# [2.60.0] - Unreleased +# [2.60.0] - 2024-10-17 ## Highlights diff --git a/website/www/site/config.toml b/website/www/site/config.toml index e937289fbde7..d769f8434a7f 100644 --- a/website/www/site/config.toml +++ b/website/www/site/config.toml @@ -104,7 +104,7 @@ github_project_repo = "https://github.com/apache/beam" [params] description = "Apache Beam is an open source, unified model and set of language-specific SDKs for defining and executing data processing workflows, and also data ingestion and integration flows, supporting Enterprise Integration Patterns (EIPs) and Domain Specific Languages (DSLs). Dataflow pipelines simplify the mechanics of large-scale batch and streaming data processing and can run on a number of runtimes like Apache Flink, Apache Spark, and Google Cloud Dataflow (a cloud service). Beam also brings DSL in different languages, allowing users to easily implement their data integration processes." -release_latest = "2.59.0" +release_latest = "2.60.0" # The repository and branch where the files live in Github or Colab. This is used # to serve and stage from your local branch, but publish to the master branch. # e.g. https://github.com/{{< param branch_repo >}}/path/to/notebook.ipynb From 79528e17b5f3a959b9d52087eeb30a6fb4806f0f Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Fri, 18 Oct 2024 11:14:31 -0400 Subject: [PATCH 67/82] Add RAG to docs (#32859) --- .../ml/large-language-modeling.md | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/website/www/site/content/en/documentation/ml/large-language-modeling.md b/website/www/site/content/en/documentation/ml/large-language-modeling.md index 90bbd43383c0..b8bd0704d20e 100644 --- a/website/www/site/content/en/documentation/ml/large-language-modeling.md +++ b/website/www/site/content/en/documentation/ml/large-language-modeling.md @@ -170,3 +170,32 @@ class MyModelHandler(): def run_inference(self, batch: Sequence[str], model: MyWrapper, inference_args): return model.predict(unpickleable_object) ``` + +## RAG and Prompt Engineering in Beam + +Beam is also an excellent tool for improving the quality of your LLM prompts using Retrieval Augmented Generation (RAG). +Retrieval augmented generation is a technique that enhances large language models (LLMs) by connecting them to external knowledge sources. +This allows the LLM to access and process real-time information, improving the accuracy, relevance, and factuality of its responses. + +Beam has several mechanisms to make this process simpler: + +1. Beam's [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) provides an embeddings package to generate the embeddings used for RAG. You can also use RunInference to generate embeddings if you have a model without an embeddings handler. +2. Beam's [Enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) makes it easy to look up embeddings or other information in an external storage system like a [vector database](https://www.pinecone.io/learn/vector-database/). + +Collectively, you can use these to perform RAG using the following steps: + +**Pipeline 1 - generate knowledge base:** + +1. Ingest data from external source using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) +2. 
Generate embeddings on that data using [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) +3. Write those embeddings to a vector DB using a [ParDo](https://beam.apache.org/documentation/programming-guide/#pardo) + +**Pipeline 2 - use knowledge base to perform RAG:** + +1. Ingest data from external source using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) +2. Generate embeddings on that data using [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) +3. Enrich that data with additional embeddings from your vector DB using [Enrichment](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) +4. Use that enriched data to prompt your LLM with [RunInference](https://beam.apache.org/documentation/transforms/python/elementwise/runinference/) +5. Write that data to your desired sink using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) + +To view an example pipeline performing RAG, see https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/rag_usecase/beam_rag_notebook.ipynb From fa9eb2fe17f5f96b40275fe7b0a3981f4a52e0df Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Fri, 18 Oct 2024 11:15:02 -0400 Subject: [PATCH 68/82] Enrichment pydoc improvements (#32861) --- .../apache_beam/yaml/yaml_enrichment.py | 62 +++++++------------ 1 file changed, 21 insertions(+), 41 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment.py b/sdks/python/apache_beam/yaml/yaml_enrichment.py index 00f2a5c1b1d1..9bea17f78fdd 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment.py @@ -48,7 +48,19 @@ def enrichment_transform( """ The Enrichment transform allows you to dynamically enhance elements in a pipeline by performing key-value - lookups against external services like APIs or databases. + lookups against external services like APIs or databases. + + Example Usage:: + + - type: Enrichment + config: + enrichment_handler: 'BigTable' + handler_config: + project_id: 'apache-beam-testing' + instance_id: 'beam-test' + table_id: 'bigtable-enrichment-test' + row_key: 'product_id' + timeout: 30 Args: enrichment_handler: Specifies the source from @@ -58,46 +70,14 @@ def enrichment_transform( "BigTable", "FeastFeatureStore", "VertexAIFeatureStore"]. handler_config: Specifies the parameters for - the respective enrichment_handler in a dictionary format. 
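As an illustration of how a YAML `handler_config` maps onto the underlying Python API (and of the Enrichment step in the RAG pipelines listed earlier), here is a hedged sketch using the BigTable handler linked just below; the project, instance, and table IDs are the same placeholder values as the YAML example, not a real instance:

```python
# Python equivalent of the YAML Enrichment example (illustrative; the
# BigTable project/instance/table values are placeholders).
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler

handler = BigTableEnrichmentHandler(
    project_id='apache-beam-testing',
    instance_id='beam-test',
    table_id='bigtable-enrichment-test',
    row_key='product_id')

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([beam.Row(product_id=1, quantity=2)])
      | Enrichment(handler)  # joins each element with the matching BigTable row
      | beam.Map(print))
```

The YAML form is just a declarative wrapper: `enrichment_handler` selects the handler class and `handler_config` supplies its constructor arguments.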
- BigQuery = ( - "BigQuery: " - "project, table_name, row_restriction_template, " - "fields, column_names, "condition_value_fn, " - "query_fn, min_batch_size, max_batch_size" - ) - - BigTable = ( - "BigTable: " - "project_id, instance_id, table_id, " - "row_key, row_filter, app_profile_id, " - "encoding, ow_key_fn, exception_level, include_timestamp" - ) - - FeastFeatureStore = ( - "FeastFeatureStore: " - "feature_store_yaml_path, feature_names, " - "feature_service_name, full_feature_names, " - "entity_row_fn, exception_level" - ) - - VertexAIFeatureStore = ( - "VertexAIFeatureStore: " - "project, location, api_endpoint, feature_store_name, " - "feature_view_name, row_key, exception_level" - ) - - Example Usage: - - - type: Enrichment - config: - enrichment_handler: 'BigTable' - handler_config: - project_id: 'apache-beam-testing' - instance_id: 'beam-test' - table_id: 'bigtable-enrichment-test' - row_key: 'product_id' - timeout: 30 - + the respective enrichment_handler in a dictionary format. + To see the full set of handler_config parameters, see + their corresponding doc pages: + + - :class:`~apache_beam.transforms.enrichment_handlers.bigquery.BigQueryEnrichmentHandler` # pylint: disable=line-too-long + - :class:`~apache_beam.transforms.enrichment_handlers.bigtable.BigTableEnrichmentHandler` # pylint: disable=line-too-long + - :class:`~apache_beam.transforms.enrichment_handlers.feast_feature_store.FeastFeatureStoreEnrichmentHandler` # pylint: disable=line-too-long + - :class:`~apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store.VertexAIFeatureStoreEnrichmentHandler` # pylint: disable=line-too-long """ options.YamlOptions.check_enabled(pcoll.pipeline, 'Enrichment') From 3fd3db30d07e1f25ec91df9dece707c371977a52 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 17 Oct 2024 08:26:39 -0400 Subject: [PATCH 69/82] Drop Flink 1.15 support --- CHANGES.md | 1 + gradle.properties | 2 +- .../runner-concepts/description.md | 8 ++-- runners/flink/1.15/build.gradle | 25 ----------- .../1.15/job-server-container/build.gradle | 26 ----------- runners/flink/1.15/job-server/build.gradle | 31 ------------- .../types/CoderTypeSerializer.java | 0 .../streaming/MemoryStateBackendWrapper.java | 0 .../flink/streaming/StreamSources.java | 0 runners/flink/flink_runner.gradle | 43 ++++++------------- .../src/apache_beam/runners/flink.ts | 2 +- settings.gradle.kts | 4 -- .../content/en/documentation/runners/flink.md | 3 +- 13 files changed, 21 insertions(+), 124 deletions(-) delete mode 100644 runners/flink/1.15/build.gradle delete mode 100644 runners/flink/1.15/job-server-container/build.gradle delete mode 100644 runners/flink/1.15/job-server/build.gradle rename runners/flink/{1.15 => 1.16}/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java (100%) rename runners/flink/{1.15 => 1.16}/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java (100%) rename runners/flink/{1.15 => 1.16}/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java (100%) diff --git a/CHANGES.md b/CHANGES.md index 30f904d7733a..2ccaeeb49f7e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -82,6 +82,7 @@ ## Deprecations +* Removed support for Flink 1.15 * X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). 
## Bugfixes diff --git a/gradle.properties b/gradle.properties index f6e143690a34..868c7501ac31 100644 --- a/gradle.properties +++ b/gradle.properties @@ -39,6 +39,6 @@ docker_image_default_repo_root=apache docker_image_default_repo_prefix=beam_ # supported flink versions -flink_versions=1.15,1.16,1.17,1.18,1.19 +flink_versions=1.16,1.17,1.18,1.19 # supported python versions python_versions=3.8,3.9,3.10,3.11,3.12 diff --git a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md index 6eb1c04e966a..063e7f35f876 100644 --- a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md +++ b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md @@ -191,8 +191,8 @@ $ wordcount --input gs://dataflow-samples/shakespeare/kinglear.txt \ {{if (eq .Sdk "java")}} ##### Portable -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.15`, `Flink 1.16`, `Flink 1.17`, `Flink 1.18`. -2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest` +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. +2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.19_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: ``` @@ -233,8 +233,8 @@ mvn exec:java -Dexec.mainClass=org.apache.beam.examples.WordCount \ {{end}} {{if (eq .Sdk "python")}} -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.10`, `Flink 1.11`, `Flink 1.12`, `Flink 1.13`, `Flink 1.14`. -2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest` +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. +2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.19_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: ``` diff --git a/runners/flink/1.15/build.gradle b/runners/flink/1.15/build.gradle deleted file mode 100644 index 8055cf593ad0..000000000000 --- a/runners/flink/1.15/build.gradle +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -project.ext { - flink_major = '1.15' - flink_version = '1.15.0' -} - -// Load the main build script which contains all build logic. -apply from: "../flink_runner.gradle" diff --git a/runners/flink/1.15/job-server-container/build.gradle b/runners/flink/1.15/job-server-container/build.gradle deleted file mode 100644 index afdb68a0fc91..000000000000 --- a/runners/flink/1.15/job-server-container/build.gradle +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -def basePath = '../../job-server-container' - -project.ext { - resource_path = basePath -} - -// Load the main build script which contains all build logic. -apply from: "$basePath/flink_job_server_container.gradle" diff --git a/runners/flink/1.15/job-server/build.gradle b/runners/flink/1.15/job-server/build.gradle deleted file mode 100644 index 05ad8feb5b78..000000000000 --- a/runners/flink/1.15/job-server/build.gradle +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -def basePath = '../../job-server' - -project.ext { - // Look for the source code in the parent module - main_source_dirs = ["$basePath/src/main/java"] - test_source_dirs = ["$basePath/src/test/java"] - main_resources_dirs = ["$basePath/src/main/resources"] - test_resources_dirs = ["$basePath/src/test/resources"] - archives_base_name = 'beam-runners-flink-1.15-job-server' -} - -// Load the main build script which contains all build logic. 
-apply from: "$basePath/flink_job_server.gradle" diff --git a/runners/flink/1.15/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java similarity index 100% rename from runners/flink/1.15/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java rename to runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java diff --git a/runners/flink/1.15/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java b/runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java similarity index 100% rename from runners/flink/1.15/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java rename to runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java diff --git a/runners/flink/1.15/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java b/runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java similarity index 100% rename from runners/flink/1.15/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java rename to runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle index c8f492a901d3..d13e1c5faf6e 100644 --- a/runners/flink/flink_runner.gradle +++ b/runners/flink/flink_runner.gradle @@ -173,36 +173,19 @@ dependencies { implementation library.java.joda_time implementation library.java.args4j - // Flink 1.15 shades all remaining scala dependencies and therefor does not depend on a specific version of Scala anymore - if (flink_version.compareTo("1.15") >= 0) { - implementation "org.apache.flink:flink-clients:$flink_version" - // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation - // configuration (https://issues.apache.org/jira/browse/BEAM-11732). - permitUnusedDeclared "org.apache.flink:flink-clients:$flink_version" - - implementation "org.apache.flink:flink-streaming-java:$flink_version" - // RocksDB state backend (included in the Flink distribution) - provided "org.apache.flink:flink-statebackend-rocksdb:$flink_version" - testImplementation "org.apache.flink:flink-statebackend-rocksdb:$flink_version" - testImplementation "org.apache.flink:flink-streaming-java:$flink_version:tests" - testImplementation "org.apache.flink:flink-test-utils:$flink_version" - - miniCluster "org.apache.flink:flink-runtime-web:$flink_version" - } else { - implementation "org.apache.flink:flink-clients_2.12:$flink_version" - // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation - // configuration (https://issues.apache.org/jira/browse/BEAM-11732). 
- permitUnusedDeclared "org.apache.flink:flink-clients_2.12:$flink_version" - - implementation "org.apache.flink:flink-streaming-java_2.12:$flink_version" - // RocksDB state backend (included in the Flink distribution) - provided "org.apache.flink:flink-statebackend-rocksdb_2.12:$flink_version" - testImplementation "org.apache.flink:flink-statebackend-rocksdb_2.12:$flink_version" - testImplementation "org.apache.flink:flink-streaming-java_2.12:$flink_version:tests" - testImplementation "org.apache.flink:flink-test-utils_2.12:$flink_version" - - miniCluster "org.apache.flink:flink-runtime-web_2.12:$flink_version" - } + implementation "org.apache.flink:flink-clients:$flink_version" + // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation + // configuration (https://issues.apache.org/jira/browse/BEAM-11732). + permitUnusedDeclared "org.apache.flink:flink-clients:$flink_version" + + implementation "org.apache.flink:flink-streaming-java:$flink_version" + // RocksDB state backend (included in the Flink distribution) + provided "org.apache.flink:flink-statebackend-rocksdb:$flink_version" + testImplementation "org.apache.flink:flink-statebackend-rocksdb:$flink_version" + testImplementation "org.apache.flink:flink-streaming-java:$flink_version:tests" + testImplementation "org.apache.flink:flink-test-utils:$flink_version" + + miniCluster "org.apache.flink:flink-runtime-web:$flink_version" implementation "org.apache.flink:flink-core:$flink_version" implementation "org.apache.flink:flink-metrics-core:$flink_version" diff --git a/sdks/typescript/src/apache_beam/runners/flink.ts b/sdks/typescript/src/apache_beam/runners/flink.ts index e21876c0d517..b68d3070a720 100644 --- a/sdks/typescript/src/apache_beam/runners/flink.ts +++ b/sdks/typescript/src/apache_beam/runners/flink.ts @@ -28,7 +28,7 @@ import { JavaJarService } from "../utils/service"; const MAGIC_HOST_NAMES = ["[local]", "[auto]"]; // These should stay in sync with gradle.properties. -const PUBLISHED_FLINK_VERSIONS = ["1.15", "1.16", "1.17", "1.18", "1.19"]; +const PUBLISHED_FLINK_VERSIONS = ["1.16", "1.17", "1.18", "1.19"]; const defaultOptions = { flinkMaster: "[local]", diff --git a/settings.gradle.kts b/settings.gradle.kts index b71ed1ede134..67e499e1ea31 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -125,10 +125,6 @@ include(":runners:extensions-java:metrics") * verify versions in website/www/site/content/en/documentation/runners/flink.md * verify version in sdks/python/apache_beam/runners/interactive/interactive_beam.py */ -// Flink 1.15 -include(":runners:flink:1.15") -include(":runners:flink:1.15:job-server") -include(":runners:flink:1.15:job-server-container") // Flink 1.16 include(":runners:flink:1.16") include(":runners:flink:1.16:job-server") diff --git a/website/www/site/content/en/documentation/runners/flink.md b/website/www/site/content/en/documentation/runners/flink.md index 2c28aa7062ec..9bf99cf9e4c2 100644 --- a/website/www/site/content/en/documentation/runners/flink.md +++ b/website/www/site/content/en/documentation/runners/flink.md @@ -196,7 +196,6 @@ The optional `flink_version` option may be required as well for older versions o {{< paragraph class="language-portable" >}} Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: -[Flink 1.15](https://hub.docker.com/r/apache/beam_flink1.15_job_server). [Flink 1.16](https://hub.docker.com/r/apache/beam_flink1.16_job_server). 
[Flink 1.17](https://hub.docker.com/r/apache/beam_flink1.17_job_server). [Flink 1.18](https://hub.docker.com/r/apache/beam_flink1.18_job_server). @@ -350,7 +349,7 @@ To find out which version of Flink is compatible with Beam please see the table 1.15.x beam-runners-flink-1.15 - ≥ 2.40.0 + 2.40.0 - 2.60.0 1.14.x From 56b54a4b66fbaafae57bcd5dd019ac0c183ee141 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 17 Oct 2024 13:58:07 -0400 Subject: [PATCH 70/82] Drop Flink 1.16 support --- CHANGES.md | 2 +- gradle.properties | 2 +- .../runner-concepts/description.md | 4 +- runners/flink/1.16/build.gradle | 25 --- .../1.16/job-server-container/build.gradle | 26 --- runners/flink/1.16/job-server/build.gradle | 31 --- .../types/CoderTypeSerializer.java | 195 ------------------ .../streaming/MemoryStateBackendWrapper.java | 0 .../flink/streaming/StreamSources.java | 0 .../src/apache_beam/runners/flink.ts | 2 +- settings.gradle.kts | 4 - .../content/en/documentation/runners/flink.md | 7 +- 12 files changed, 8 insertions(+), 290 deletions(-) delete mode 100644 runners/flink/1.16/build.gradle delete mode 100644 runners/flink/1.16/job-server-container/build.gradle delete mode 100644 runners/flink/1.16/job-server/build.gradle delete mode 100644 runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java rename runners/flink/{1.16 => 1.17}/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java (100%) rename runners/flink/{1.16 => 1.17}/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java (100%) diff --git a/CHANGES.md b/CHANGES.md index 2ccaeeb49f7e..0167b575f1de 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -82,7 +82,7 @@ ## Deprecations -* Removed support for Flink 1.15 +* Removed support for Flink 1.15 and 1.16 * X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). ## Bugfixes diff --git a/gradle.properties b/gradle.properties index 868c7501ac31..db1db368beb0 100644 --- a/gradle.properties +++ b/gradle.properties @@ -39,6 +39,6 @@ docker_image_default_repo_root=apache docker_image_default_repo_prefix=beam_ # supported flink versions -flink_versions=1.16,1.17,1.18,1.19 +flink_versions=1.17,1.18,1.19 # supported python versions python_versions=3.8,3.9,3.10,3.11,3.12 diff --git a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md index 063e7f35f876..c0d7b37725ac 100644 --- a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md +++ b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md @@ -191,7 +191,7 @@ $ wordcount --input gs://dataflow-samples/shakespeare/kinglear.txt \ {{if (eq .Sdk "java")}} ##### Portable -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. 2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.19_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). 
Optionally set environment_type set to LOOPBACK. For example: @@ -233,7 +233,7 @@ mvn exec:java -Dexec.mainClass=org.apache.beam.examples.WordCount \ {{end}} {{if (eq .Sdk "python")}} -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. 2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.19_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: diff --git a/runners/flink/1.16/build.gradle b/runners/flink/1.16/build.gradle deleted file mode 100644 index 21a222864a27..000000000000 --- a/runners/flink/1.16/build.gradle +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -project.ext { - flink_major = '1.16' - flink_version = '1.16.0' -} - -// Load the main build script which contains all build logic. -apply from: "../flink_runner.gradle" diff --git a/runners/flink/1.16/job-server-container/build.gradle b/runners/flink/1.16/job-server-container/build.gradle deleted file mode 100644 index afdb68a0fc91..000000000000 --- a/runners/flink/1.16/job-server-container/build.gradle +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -def basePath = '../../job-server-container' - -project.ext { - resource_path = basePath -} - -// Load the main build script which contains all build logic. 
-apply from: "$basePath/flink_job_server_container.gradle" diff --git a/runners/flink/1.16/job-server/build.gradle b/runners/flink/1.16/job-server/build.gradle deleted file mode 100644 index 99dc00275a0c..000000000000 --- a/runners/flink/1.16/job-server/build.gradle +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -def basePath = '../../job-server' - -project.ext { - // Look for the source code in the parent module - main_source_dirs = ["$basePath/src/main/java"] - test_source_dirs = ["$basePath/src/test/java"] - main_resources_dirs = ["$basePath/src/main/resources"] - test_resources_dirs = ["$basePath/src/test/resources"] - archives_base_name = 'beam-runners-flink-1.16-job-server' -} - -// Load the main build script which contains all build logic. -apply from: "$basePath/flink_job_server.gradle" diff --git a/runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java deleted file mode 100644 index 956aad428d8b..000000000000 --- a/runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.flink.translation.types; - -import java.io.EOFException; -import java.io.IOException; -import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -import org.apache.beam.runners.flink.FlinkPipelineOptions; -import org.apache.beam.runners.flink.translation.wrappers.DataInputViewWrapper; -import org.apache.beam.runners.flink.translation.wrappers.DataOutputViewWrapper; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; -import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; -import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; -import org.apache.flink.core.io.VersionedIOReadableWritable; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** - * Flink {@link org.apache.flink.api.common.typeutils.TypeSerializer} for Beam {@link - * org.apache.beam.sdk.coders.Coder Coders}. - */ -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class CoderTypeSerializer extends TypeSerializer { - - private static final long serialVersionUID = 7247319138941746449L; - - private final Coder coder; - - /** - * {@link SerializablePipelineOptions} deserialization will cause {@link - * org.apache.beam.sdk.io.FileSystems} registration needed for {@link - * org.apache.beam.sdk.transforms.Reshuffle} translation. 
- */ - private final SerializablePipelineOptions pipelineOptions; - - private final boolean fasterCopy; - - public CoderTypeSerializer(Coder coder, SerializablePipelineOptions pipelineOptions) { - Preconditions.checkNotNull(coder); - Preconditions.checkNotNull(pipelineOptions); - this.coder = coder; - this.pipelineOptions = pipelineOptions; - - FlinkPipelineOptions options = pipelineOptions.get().as(FlinkPipelineOptions.class); - this.fasterCopy = options.getFasterCopy(); - } - - @Override - public boolean isImmutableType() { - return false; - } - - @Override - public CoderTypeSerializer duplicate() { - return new CoderTypeSerializer<>(coder, pipelineOptions); - } - - @Override - public T createInstance() { - return null; - } - - @Override - public T copy(T t) { - if (fasterCopy) { - return t; - } - try { - return CoderUtils.clone(coder, t); - } catch (CoderException e) { - throw new RuntimeException("Could not clone.", e); - } - } - - @Override - public T copy(T t, T reuse) { - return copy(t); - } - - @Override - public int getLength() { - return -1; - } - - @Override - public void serialize(T t, DataOutputView dataOutputView) throws IOException { - DataOutputViewWrapper outputWrapper = new DataOutputViewWrapper(dataOutputView); - coder.encode(t, outputWrapper); - } - - @Override - public T deserialize(DataInputView dataInputView) throws IOException { - try { - DataInputViewWrapper inputWrapper = new DataInputViewWrapper(dataInputView); - return coder.decode(inputWrapper); - } catch (CoderException e) { - Throwable cause = e.getCause(); - if (cause instanceof EOFException) { - throw (EOFException) cause; - } else { - throw e; - } - } - } - - @Override - public T deserialize(T t, DataInputView dataInputView) throws IOException { - return deserialize(dataInputView); - } - - @Override - public void copy(DataInputView dataInputView, DataOutputView dataOutputView) throws IOException { - serialize(deserialize(dataInputView), dataOutputView); - } - - @Override - public boolean equals(@Nullable Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - CoderTypeSerializer that = (CoderTypeSerializer) o; - return coder.equals(that.coder); - } - - @Override - public int hashCode() { - return coder.hashCode(); - } - - @Override - public TypeSerializerSnapshot snapshotConfiguration() { - return new UnversionedTypeSerializerSnapshot<>(this); - } - - /** - * A legacy snapshot which does not care about schema compatibility. This is used only for state - * restore of state created by Beam 2.54.0 and below for Flink 1.16 and below. - */ - public static class LegacySnapshot extends TypeSerializerConfigSnapshot { - - /** Needs to be public to work with {@link VersionedIOReadableWritable}. 
*/ - public LegacySnapshot() {} - - public LegacySnapshot(CoderTypeSerializer serializer) { - setPriorSerializer(serializer); - } - - @Override - public int getVersion() { - // We always return the same version - return 1; - } - - @Override - public TypeSerializerSchemaCompatibility resolveSchemaCompatibility( - TypeSerializer newSerializer) { - - // We assume compatibility because we don't have a way of checking schema compatibility - return TypeSerializerSchemaCompatibility.compatibleAsIs(); - } - } - - @Override - public String toString() { - return "CoderTypeSerializer{" + "coder=" + coder + '}'; - } -} diff --git a/runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java b/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java similarity index 100% rename from runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java rename to runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java diff --git a/runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java b/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java similarity index 100% rename from runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java rename to runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java diff --git a/sdks/typescript/src/apache_beam/runners/flink.ts b/sdks/typescript/src/apache_beam/runners/flink.ts index b68d3070a720..ab2d641b3302 100644 --- a/sdks/typescript/src/apache_beam/runners/flink.ts +++ b/sdks/typescript/src/apache_beam/runners/flink.ts @@ -28,7 +28,7 @@ import { JavaJarService } from "../utils/service"; const MAGIC_HOST_NAMES = ["[local]", "[auto]"]; // These should stay in sync with gradle.properties. -const PUBLISHED_FLINK_VERSIONS = ["1.16", "1.17", "1.18", "1.19"]; +const PUBLISHED_FLINK_VERSIONS = ["1.17", "1.18", "1.19"]; const defaultOptions = { flinkMaster: "[local]", diff --git a/settings.gradle.kts b/settings.gradle.kts index 67e499e1ea31..a38f69dac09e 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -125,10 +125,6 @@ include(":runners:extensions-java:metrics") * verify versions in website/www/site/content/en/documentation/runners/flink.md * verify version in sdks/python/apache_beam/runners/interactive/interactive_beam.py */ -// Flink 1.16 -include(":runners:flink:1.16") -include(":runners:flink:1.16:job-server") -include(":runners:flink:1.16:job-server-container") // Flink 1.17 include(":runners:flink:1.17") include(":runners:flink:1.17:job-server") diff --git a/website/www/site/content/en/documentation/runners/flink.md b/website/www/site/content/en/documentation/runners/flink.md index 9bf99cf9e4c2..fb897805cfd6 100644 --- a/website/www/site/content/en/documentation/runners/flink.md +++ b/website/www/site/content/en/documentation/runners/flink.md @@ -196,7 +196,6 @@ The optional `flink_version` option may be required as well for older versions o {{< paragraph class="language-portable" >}} Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: -[Flink 1.16](https://hub.docker.com/r/apache/beam_flink1.16_job_server). [Flink 1.17](https://hub.docker.com/r/apache/beam_flink1.17_job_server). [Flink 1.18](https://hub.docker.com/r/apache/beam_flink1.18_job_server). 
[Flink 1.19](https://hub.docker.com/r/apache/beam_flink1.19_job_server). @@ -312,8 +311,8 @@ reference. ## Flink Version Compatibility The Flink cluster version has to match the minor version used by the FlinkRunner. -The minor version is the first two numbers in the version string, e.g. in `1.16.0` the -minor version is `1.16`. +The minor version is the first two numbers in the version string, e.g. in `1.19.0` the +minor version is `1.19`. We try to track the latest version of Apache Flink at the time of the Beam release. A Flink version is supported by Beam for the time it is supported by the Flink community. @@ -344,7 +343,7 @@ To find out which version of Flink is compatible with Beam please see the table 1.16.x beam-runners-flink-1.16 - ≥ 2.47.0 + 2.47.0 - 2.60.0 1.15.x From dfa54b23e8d4143275f4bd2c0f90d85944ae76ee Mon Sep 17 00:00:00 2001 From: Jan Lukavsky Date: Fri, 18 Oct 2024 09:50:48 +0200 Subject: [PATCH 71/82] [flink] #32838 remove removed flink version references --- sdks/go/examples/wasm/README.md | 2 +- sdks/python/apache_beam/options/pipeline_options.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/go/examples/wasm/README.md b/sdks/go/examples/wasm/README.md index a78649134305..103bef88642b 100644 --- a/sdks/go/examples/wasm/README.md +++ b/sdks/go/examples/wasm/README.md @@ -68,7 +68,7 @@ cd $BEAM_HOME Expected output should include the following, from which you acquire the latest flink runner version. ```shell -'flink_versions: 1.15,1.16,1.17,1.18,1.19' +'flink_versions: 1.17,1.18,1.19' ``` #### 2. Set to the latest flink runner version i.e. 1.16 diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 837dc0f5439f..455d12b4d3c1 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1679,7 +1679,7 @@ def _add_argparse_args(cls, parser): class FlinkRunnerOptions(PipelineOptions): # These should stay in sync with gradle.properties. - PUBLISHED_FLINK_VERSIONS = ['1.15', '1.16', '1.17', '1.18', '1.19'] + PUBLISHED_FLINK_VERSIONS = ['1.17', '1.18', '1.19'] @classmethod def _add_argparse_args(cls, parser): From ef6caf4c65c09990e169311750804856a734e5bd Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 18 Oct 2024 14:10:14 -0700 Subject: [PATCH 72/82] Added a TODO. --- sdks/python/apache_beam/yaml/tests/map.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/python/apache_beam/yaml/tests/map.yaml b/sdks/python/apache_beam/yaml/tests/map.yaml index 31fb442085fb..bbb7fc4527de 100644 --- a/sdks/python/apache_beam/yaml/tests/map.yaml +++ b/sdks/python/apache_beam/yaml/tests/map.yaml @@ -30,6 +30,7 @@ pipelines: config: append: true fields: + # TODO(https://github.com/apache/beam/issues/32832): Figure out why Java sometimes re-orders these fields. literal_int: 10 named_field: element literal_float: 1.5 From eafe08b7d9c4bb28695d27d9933ff144cd657714 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 18 Oct 2024 14:44:57 -0700 Subject: [PATCH 73/82] Update docs on error handling output schema. 
--- .../www/site/content/en/documentation/sdks/yaml-errors.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/website/www/site/content/en/documentation/sdks/yaml-errors.md b/website/www/site/content/en/documentation/sdks/yaml-errors.md index 8c0d9f06ade3..903e18d6b3c7 100644 --- a/website/www/site/content/en/documentation/sdks/yaml-errors.md +++ b/website/www/site/content/en/documentation/sdks/yaml-errors.md @@ -37,7 +37,8 @@ The `output` parameter is a name that must referenced as an input to another transform that will process the errors (e.g. by writing them out). For example, the following code will write all "good" processed records to one file and -any "bad" records to a separate file. +any "bad" records, along with metadata about what error was encountered, +to a separate file. ``` pipeline: @@ -77,6 +78,8 @@ for a robust pipeline). Note also that the exact format of the error outputs is still being finalized. They can be safely printed and written to outputs, but their precise schema may change in a future version of Beam and should not yet be depended on. +Currently it has, at the very least, an `element` field which holds the element +that caused the error. Some transforms allow for extra arguments in their error_handling config, e.g. for Python functions one can give a `threshold` which limits the relative number From 1ba33b888fc76a1e25cd7ad45ec9fde642b6f572 Mon Sep 17 00:00:00 2001 From: Robert Burke Date: Sun, 20 Oct 2024 08:39:04 -0700 Subject: [PATCH 74/82] [#30703][prism] Update logging handling (#32826) * Migrate to standard library slog package * Add dev logger dependency for pre printed development logs * Improve logging output for prism and user side logs, and emit container logs. * Fix missed lines from artifact and worker. --------- Co-authored-by: lostluck <13907733+lostluck@users.noreply.github.com> --- sdks/go.mod | 4 +- sdks/go.sum | 2 + sdks/go/cmd/prism/prism.go | 47 ++++++++++++++ .../beam/core/runtime/metricsx/metricsx.go | 2 +- .../pkg/beam/runners/prism/internal/coders.go | 2 +- .../runners/prism/internal/engine/data.go | 2 +- .../prism/internal/engine/elementmanager.go | 4 +- .../runners/prism/internal/environments.go | 32 ++++++++-- .../beam/runners/prism/internal/execute.go | 14 ++-- .../runners/prism/internal/handlerunner.go | 2 +- .../prism/internal/jobservices/artifact.go | 4 +- .../runners/prism/internal/jobservices/job.go | 4 +- .../prism/internal/jobservices/management.go | 3 +- .../prism/internal/jobservices/metrics.go | 4 +- .../prism/internal/jobservices/server.go | 6 +- .../beam/runners/prism/internal/preprocess.go | 2 +- .../runners/prism/internal/separate_test.go | 12 ++-- .../pkg/beam/runners/prism/internal/stage.go | 4 +- .../beam/runners/prism/internal/web/web.go | 2 +- .../runners/prism/internal/worker/bundle.go | 2 +- .../runners/prism/internal/worker/worker.go | 64 +++++++++++-------- 21 files changed, 153 insertions(+), 65 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index 74556ee12a55..706be73f97f6 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -20,7 +20,7 @@ // directory. 
module github.com/apache/beam/sdks/v2 -go 1.21 +go 1.21.0 require ( cloud.google.com/go/bigquery v1.63.1 @@ -69,6 +69,8 @@ require ( require ( github.com/avast/retry-go/v4 v4.6.0 github.com/fsouza/fake-gcs-server v1.49.2 + github.com/golang-cz/devslog v0.0.11 + github.com/golang/protobuf v1.5.4 golang.org/x/exp v0.0.0-20231006140011-7918f672742d ) diff --git a/sdks/go.sum b/sdks/go.sum index af68a630addd..fa3c75bd3395 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -853,6 +853,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-cz/devslog v0.0.11 h1:v4Yb9o0ZpuZ/D8ZrtVw1f9q5XrjnkxwHF1XmWwO8IHg= +github.com/golang-cz/devslog v0.0.11/go.mod h1:bSe5bm0A7Nyfqtijf1OMNgVJHlWEuVSXnkuASiE1vV8= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= diff --git a/sdks/go/cmd/prism/prism.go b/sdks/go/cmd/prism/prism.go index 39c19df00dc3..070d2f023b74 100644 --- a/sdks/go/cmd/prism/prism.go +++ b/sdks/go/cmd/prism/prism.go @@ -22,9 +22,14 @@ import ( "flag" "fmt" "log" + "log/slog" + "os" + "strings" + "time" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism" + "github.com/golang-cz/devslog" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -37,10 +42,52 @@ var ( idleShutdownTimeout = flag.Duration("idle_shutdown_timeout", -1, "duration that prism will wait for a new job before shutting itself down. Negative durations disable auto shutdown. Defaults to never shutting down.") ) +// Logging flags +var ( + debug = flag.Bool("debug", false, + "Enables full verbosity debug logging from the runner by default. Used to build SDKs or debug Prism itself.") + logKind = flag.String("log_kind", "dev", + "Determines the format of prism's logging to std err: valid values are `dev', 'json', or 'text'. Default is `dev`.") +) + +var logLevel = new(slog.LevelVar) + func main() { flag.Parse() ctx, cancel := context.WithCancelCause(context.Background()) + var logHandler slog.Handler + loggerOutput := os.Stderr + handlerOpts := &slog.HandlerOptions{ + Level: logLevel, + AddSource: *debug, + } + if *debug { + logLevel.Set(slog.LevelDebug) + // Print the Prism source line for a log in debug mode. 
+ handlerOpts.AddSource = true + } + switch strings.ToLower(*logKind) { + case "dev": + logHandler = + devslog.NewHandler(loggerOutput, &devslog.Options{ + TimeFormat: "[" + time.RFC3339Nano + "]", + StringerFormatter: true, + HandlerOptions: handlerOpts, + StringIndentation: false, + NewLineAfterLog: true, + MaxErrorStackTrace: 3, + }) + case "json": + logHandler = slog.NewJSONHandler(loggerOutput, handlerOpts) + case "text": + logHandler = slog.NewTextHandler(loggerOutput, handlerOpts) + default: + log.Fatalf("Invalid value for log_kind: %v, must be 'dev', 'json', or 'text'", *logKind) + } + + slog.SetDefault(slog.New(logHandler)) + cli, err := makeJobClient(ctx, prism.Options{ Port: *jobPort, diff --git a/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go b/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go index c71ead208364..06bb727178fc 100644 --- a/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go +++ b/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go @@ -19,12 +19,12 @@ import ( "bytes" "fmt" "log" + "log/slog" "time" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/metrics" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" - "golang.org/x/exp/slog" ) // FromMonitoringInfos extracts metrics from monitored states and diff --git a/sdks/go/pkg/beam/runners/prism/internal/coders.go b/sdks/go/pkg/beam/runners/prism/internal/coders.go index eb8abe16ecf8..ffea90e79065 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/coders.go +++ b/sdks/go/pkg/beam/runners/prism/internal/coders.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "strings" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" @@ -28,7 +29,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go index eaaf7f831712..7b8689f95112 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go @@ -18,12 +18,12 @@ package engine import ( "bytes" "fmt" + "log/slog" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" - "golang.org/x/exp/slog" ) // StateData is a "union" between Bag state and MultiMap state to increase common code. 
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index f7229853e4d3..3cfde4701a8f 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -23,6 +23,7 @@ import ( "context" "fmt" "io" + "log/slog" "sort" "strings" "sync" @@ -36,7 +37,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" ) type element struct { @@ -1607,7 +1607,7 @@ func (ss *stageState) bundleReady(em *ElementManager, emNow mtime.Time) (mtime.T inputW := ss.input _, upstreamW := ss.UpstreamWatermark() if inputW == upstreamW { - slog.Debug("bundleReady: insufficient upstream watermark", + slog.Debug("bundleReady: unchanged upstream watermark", slog.String("stage", ss.ID), slog.Group("watermark", slog.Any("upstream", upstreamW), diff --git a/sdks/go/pkg/beam/runners/prism/internal/environments.go b/sdks/go/pkg/beam/runners/prism/internal/environments.go index add7f769a702..2f960a04f0cb 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/environments.go +++ b/sdks/go/pkg/beam/runners/prism/internal/environments.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "os" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" @@ -27,7 +28,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" - "golang.org/x/exp/slog" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/proto" @@ -42,7 +42,7 @@ import ( // TODO move environment handling to the worker package. func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *worker.W) error { - logger := slog.With(slog.String("envID", wk.Env)) + logger := j.Logger.With(slog.String("envID", wk.Env)) // TODO fix broken abstraction. // We're starting a worker pool here, because that's the loopback environment. // It's sort of a mess, largely because of loopback, which has @@ -56,7 +56,7 @@ func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *wor } go func() { externalEnvironment(ctx, ep, wk) - slog.Debug("environment stopped", slog.String("job", j.String())) + logger.Debug("environment stopped", slog.String("job", j.String())) }() return nil case urns.EnvDocker: @@ -129,6 +129,8 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock credEnv := fmt.Sprintf("%v=%v", gcloudCredsEnv, dockerGcloudCredsFile) envs = append(envs, credEnv) } + } else { + logger.Debug("local GCP credentials environment variable not found") } if _, _, err := cli.ImageInspectWithRaw(ctx, dp.GetContainerImage()); err != nil { // We don't have a local image, so we should pull it. 
@@ -140,6 +142,7 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock logger.Warn("unable to pull image and it's not local", "error", err) } } + logger.Debug("creating container", "envs", envs, "mounts", mounts) ccr, err := cli.ContainerCreate(ctx, &container.Config{ Image: dp.GetContainerImage(), @@ -169,17 +172,32 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock return fmt.Errorf("unable to start container image %v with docker for env %v, err: %w", dp.GetContainerImage(), wk.Env, err) } + logger.Debug("container started") + // Start goroutine to wait on container state. go func() { defer cli.Close() defer wk.Stop() + defer func() { + logger.Debug("container stopped") + }() - statusCh, errCh := cli.ContainerWait(ctx, containerID, container.WaitConditionNotRunning) + bgctx := context.Background() + statusCh, errCh := cli.ContainerWait(bgctx, containerID, container.WaitConditionNotRunning) select { case <-ctx.Done(): - // Can't use command context, since it's already canceled here. - err := cli.ContainerKill(context.Background(), containerID, "") + rc, err := cli.ContainerLogs(bgctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) if err != nil { + logger.Error("error fetching container logs error on context cancellation", "error", err) + } + if rc != nil { + defer rc.Close() + var buf bytes.Buffer + stdcopy.StdCopy(&buf, &buf, rc) + logger.Info("container being killed", slog.Any("cause", context.Cause(ctx)), slog.Any("containerLog", buf)) + } + // Can't use command context, since it's already canceled here. + if err := cli.ContainerKill(bgctx, containerID, ""); err != nil { logger.Error("docker container kill error", "error", err) } case err := <-errCh: @@ -189,7 +207,7 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock case resp := <-statusCh: logger.Info("docker container has self terminated", "status_code", resp.StatusCode) - rc, err := cli.ContainerLogs(ctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) + rc, err := cli.ContainerLogs(bgctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) if err != nil { logger.Error("docker container logs error", "error", err) } diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go index d7605f34f5f2..614edee47721 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "io" + "log/slog" "sort" "sync/atomic" "time" @@ -34,7 +35,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" ) @@ -311,7 +311,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic return fmt.Errorf("prism error building stage %v: \n%w", stage.ID, err) } stages[stage.ID] = stage - slog.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName()))) + j.Logger.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName()))) outputs := maps.Keys(stage.OutputsToCoders) sort.Strings(outputs) em.AddStage(stage.ID, []string{stage.primaryInput}, 
outputs, stage.sideInputs) @@ -322,9 +322,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic em.StageProcessingTimeTimers(stage.ID, stage.processingTimeTimers) } default: - err := fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId()) - slog.Error("Execute", err) - return err + return fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId()) } } @@ -344,11 +342,13 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic for { select { case <-ctx.Done(): - return context.Cause(ctx) + err := context.Cause(ctx) + j.Logger.Debug("context canceled", slog.Any("cause", err)) + return err case rb, ok := <-bundles: if !ok { err := eg.Wait() - slog.Debug("pipeline done!", slog.String("job", j.String()), slog.Any("error", err)) + j.Logger.Debug("pipeline done!", slog.String("job", j.String()), slog.Any("error", err), slog.Any("topo", topo)) return err } eg.Go(func() error { diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go index 8590fd0d4ced..be9d39ad02b7 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go +++ b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go @@ -19,6 +19,7 @@ import ( "bytes" "fmt" "io" + "log/slog" "reflect" "sort" @@ -31,7 +32,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go index 99b786d45980..e42e3e7ca666 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go @@ -20,9 +20,9 @@ import ( "context" "fmt" "io" + "log/slog" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" ) @@ -77,7 +77,7 @@ func (s *Server) ReverseArtifactRetrievalService(stream jobpb.ArtifactStagingSer case *jobpb.ArtifactResponseWrapper_ResolveArtifactResponse: err := fmt.Errorf("unexpected ResolveArtifactResponse to GetArtifact: %v", in.GetResponse()) - slog.Error("GetArtifact failure", err) + slog.Error("GetArtifact failure", slog.Any("error", err)) return err } } diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go index 1407feafe325..deef259a99d1 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go @@ -27,6 +27,7 @@ package jobservices import ( "context" "fmt" + "log/slog" "sort" "strings" "sync" @@ -37,7 +38,6 @@ import ( jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/protobuf/types/known/structpb" ) @@ -88,6 +88,8 @@ type Job struct { // Context used to terminate this job. RootCtx context.Context CancelFn context.CancelCauseFunc + // Logger for this job. 
+ Logger *slog.Logger metrics metricsStore } diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index b957b99ca63d..a2840760bf7a 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -19,6 +19,7 @@ import ( "context" "errors" "fmt" + "log/slog" "sync" "sync/atomic" @@ -27,7 +28,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -92,6 +92,7 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ * cancelFn(err) terminalOnceWrap() }, + Logger: s.logger, // TODO substitute with a configured logger. artifactEndpoint: s.Endpoint(), } // Stop the idle timer when a new job appears. diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go index 03d5b0a98369..bbbdfd1eba4f 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go @@ -19,6 +19,7 @@ import ( "bytes" "fmt" "hash/maphash" + "log/slog" "math" "sort" "sync" @@ -28,7 +29,6 @@ import ( fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "golang.org/x/exp/constraints" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -589,7 +589,7 @@ func (m *metricsStore) AddShortIDs(resp *fnpb.MonitoringInfosMetadataResponse) { urn := mi.GetUrn() ops, ok := mUrn2Ops[urn] if !ok { - slog.Debug("unknown metrics urn", slog.String("urn", urn)) + slog.Debug("unknown metrics urn", slog.String("shortID", short), slog.String("urn", urn), slog.String("type", mi.Type)) continue } key := ops.keyFn(urn, mi.GetLabels()) diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go index 320159f54c06..bdfe2aff2dd4 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go @@ -18,6 +18,7 @@ package jobservices import ( "context" "fmt" + "log/slog" "math" "net" "os" @@ -27,7 +28,6 @@ import ( fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" - "golang.org/x/exp/slog" "google.golang.org/grpc" ) @@ -53,6 +53,7 @@ type Server struct { terminatedJobCount uint32 // Use with atomics. idleTimeout time.Duration cancelFn context.CancelCauseFunc + logger *slog.Logger // execute defines how a job is executed. execute func(*Job) @@ -71,8 +72,9 @@ func NewServer(port int, execute func(*Job)) *Server { lis: lis, jobs: make(map[string]*Job), execute: execute, + logger: slog.Default(), // TODO substitute with a configured logger. 
} - slog.Info("Serving JobManagement", slog.String("endpoint", s.Endpoint())) + s.logger.Info("Serving JobManagement", slog.String("endpoint", s.Endpoint())) opts := []grpc.ServerOption{ grpc.MaxRecvMsgSize(math.MaxInt32), } diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go index 7de32f85b7ee..dceaa9ab8fcb 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go +++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go @@ -17,6 +17,7 @@ package internal import ( "fmt" + "log/slog" "sort" "strings" @@ -26,7 +27,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/separate_test.go b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go index 1be3d3e70841..650932f525c8 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/separate_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go @@ -18,6 +18,7 @@ package internal_test import ( "context" "fmt" + "log/slog" "net" "net/http" "net/rpc" @@ -34,7 +35,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/stats" - "golang.org/x/exp/slog" ) // separate_test.go retains structures and tests to ensure the runner can @@ -286,7 +286,7 @@ func (ws *Watchers) Check(args *Args, unblocked *bool) error { w.mu.Lock() *unblocked = w.sentinelCount >= w.sentinelCap w.mu.Unlock() - slog.Debug("sentinel target for watcher%d is %d/%d. unblocked=%v", args.WatcherID, w.sentinelCount, w.sentinelCap, *unblocked) + slog.Debug("sentinel watcher status", slog.Int("watcher", args.WatcherID), slog.Int("sentinelCount", w.sentinelCount), slog.Int("sentinelCap", w.sentinelCap), slog.Bool("unblocked", *unblocked)) return nil } @@ -360,7 +360,7 @@ func (fn *sepHarnessBase) setup() error { sepClientOnce.Do(func() { client, err := rpc.DialHTTP("tcp", fn.LocalService) if err != nil { - slog.Error("failed to dial sentinels server", err, slog.String("endpoint", fn.LocalService)) + slog.Error("failed to dial sentinels server", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic(fmt.Sprintf("dialing sentinels server %v: %v", fn.LocalService, err)) } sepClient = client @@ -385,7 +385,7 @@ func (fn *sepHarnessBase) setup() error { var unblock bool err := sepClient.Call("Watchers.Check", &Args{WatcherID: id}, &unblock) if err != nil { - slog.Error("Watchers.Check: sentinels server error", err, slog.String("endpoint", fn.LocalService)) + slog.Error("Watchers.Check: sentinels server error", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic("sentinel server error") } if unblock { @@ -406,7 +406,7 @@ func (fn *sepHarnessBase) block() { var ignored bool err := sepClient.Call("Watchers.Block", &Args{WatcherID: fn.WatcherID}, &ignored) if err != nil { - slog.Error("Watchers.Block error", err, slog.String("endpoint", fn.LocalService)) + slog.Error("Watchers.Block error", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic(err) } c := sepWaitMap[fn.WatcherID] @@ -423,7 +423,7 @@ func (fn *sepHarnessBase) delay() bool { var delay bool err := sepClient.Call("Watchers.Delay", &Args{WatcherID: fn.WatcherID}, &delay) if err != nil { - slog.Error("Watchers.Delay 
error", err) + slog.Error("Watchers.Delay error", slog.Any("error", err)) panic(err) } return delay diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index f33754b2ca0a..9f00c22789b6 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "runtime/debug" "sync/atomic" "time" @@ -33,7 +34,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" ) @@ -361,7 +361,7 @@ func portFor(wInCid string, wk *worker.W) []byte { } sourcePortBytes, err := proto.Marshal(sourcePort) if err != nil { - slog.Error("bad port", err, slog.String("endpoint", sourcePort.ApiServiceDescriptor.GetUrl())) + slog.Error("bad port", slog.Any("error", err), slog.String("endpoint", sourcePort.ApiServiceDescriptor.GetUrl())) } return sourcePortBytes } diff --git a/sdks/go/pkg/beam/runners/prism/internal/web/web.go b/sdks/go/pkg/beam/runners/prism/internal/web/web.go index 9fabe22cee3a..b14778e4462c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/web/web.go +++ b/sdks/go/pkg/beam/runners/prism/internal/web/web.go @@ -26,6 +26,7 @@ import ( "fmt" "html/template" "io" + "log/slog" "net/http" "sort" "strings" @@ -40,7 +41,6 @@ import ( jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go index 3ccafdb81e9a..55cdb97f258c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go @@ -19,12 +19,12 @@ import ( "bytes" "context" "fmt" + "log/slog" "sync/atomic" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" - "golang.org/x/exp/slog" ) // SideInputKey is for data lookups for a given bundle. 
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index f9ec03793488..1f129595abef 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -22,10 +22,9 @@ import ( "context" "fmt" "io" + "log/slog" "math" "net" - "strconv" - "strings" "sync" "sync/atomic" "time" @@ -39,7 +38,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -203,30 +201,46 @@ func (wk *W) Logging(stream fnpb.BeamFnLogging_LoggingServer) error { case codes.Canceled: return nil default: - slog.Error("logging.Recv", err, "worker", wk) + slog.Error("logging.Recv", slog.Any("error", err), slog.Any("worker", wk)) return err } } for _, l := range in.GetLogEntries() { - if l.Severity >= minsev { - // TODO: Connect to the associated Job for this worker instead of - // logging locally for SDK side logging. - file := l.GetLogLocation() - i := strings.LastIndex(file, ":") - line, _ := strconv.Atoi(file[i+1:]) - if i > 0 { - file = file[:i] - } + // TODO base this on a per pipeline logging setting. + if l.Severity < minsev { + continue + } + + // Ideally we'd be writing these to per-pipeline files, but for now re-log them on the Prism process. + // We indicate they're from the SDK, and which worker, keeping the same log severity. + // SDK specific and worker specific fields are in separate groups for legibility. - slog.LogAttrs(stream.Context(), toSlogSev(l.GetSeverity()), l.GetMessage(), - slog.Any(slog.SourceKey, &slog.Source{ - File: file, - Line: line, - }), - slog.Time(slog.TimeKey, l.GetTimestamp().AsTime()), - slog.Any("worker", wk), - ) + attrs := []any{ + slog.String("transformID", l.GetTransformId()), // TODO: pull the unique name from the pipeline graph. 
+ slog.String("location", l.GetLogLocation()), + slog.Time(slog.TimeKey, l.GetTimestamp().AsTime()), + slog.String(slog.MessageKey, l.GetMessage()), } + if fs := l.GetCustomData().GetFields(); len(fs) > 0 { + var grp []any + for n, v := range l.GetCustomData().GetFields() { + var attr slog.Attr + switch v.Kind.(type) { + case *structpb.Value_BoolValue: + attr = slog.Bool(n, v.GetBoolValue()) + case *structpb.Value_NumberValue: + attr = slog.Float64(n, v.GetNumberValue()) + case *structpb.Value_StringValue: + attr = slog.String(n, v.GetStringValue()) + default: + attr = slog.Any(n, v.AsInterface()) + } + grp = append(grp, attr) + } + attrs = append(attrs, slog.Group("customData", grp...)) + } + + slog.LogAttrs(stream.Context(), toSlogSev(l.GetSeverity()), "log from SDK worker", slog.Any("worker", wk), slog.Group("sdk", attrs...)) } } } @@ -298,7 +312,7 @@ func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error { if b, ok := wk.activeInstructions[resp.GetInstructionId()]; ok { b.Respond(resp) } else { - slog.Debug("ctrl.Recv: %v", resp) + slog.Debug("ctrl.Recv", slog.Any("response", resp)) } wk.mu.Unlock() } @@ -355,7 +369,7 @@ func (wk *W) Data(data fnpb.BeamFnData_DataServer) error { case codes.Canceled: return default: - slog.Error("data.Recv failed", err, "worker", wk) + slog.Error("data.Recv failed", slog.Any("error", err), slog.Any("worker", wk)) panic(err) } } @@ -434,7 +448,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { case codes.Canceled: return default: - slog.Error("state.Recv failed", err, "worker", wk) + slog.Error("state.Recv failed", slog.Any("error", err), slog.Any("worker", wk)) panic(err) } } @@ -584,7 +598,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { }() for resp := range responses { if err := state.Send(resp); err != nil { - slog.Error("state.Send error", err) + slog.Error("state.Send", slog.Any("error", err)) } } return nil From ac87d7b48e86e0c3e863d13b5e8d52469134a446 Mon Sep 17 00:00:00 2001 From: Thiago Nunes Date: Mon, 21 Oct 2024 18:51:09 +1100 Subject: [PATCH 75/82] fix: generate random index name for change streams (#32689) Generates index names for change stream partition metadata table using a random UUID. This prevents issues if the job is being redeployed in an existing database. 
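
The naming scheme is the same for the metadata table and both of its indexes: format the
database id and a random UUID into a prefix template, replace hyphens with underscores, and
truncate to the 63-character PostgreSQL identifier limit. The following is a minimal sketch of
that scheme, mirroring PartitionMetadataTableNames#generateName in the diff below; the class
and method names in the sketch are illustrative only and are not part of the API added by this
patch.

    import java.util.UUID;

    final class RandomNameSketch {
      // PostgreSQL limits identifiers to 63 bytes, so generated names are truncated.
      private static final int MAX_NAME_LENGTH = 63;

      // prefixFormat is e.g. "WatermarkIdx_%s_%s"; hyphens from the database id and the
      // UUID are replaced so the result stays a plain identifier.
      static String generate(String prefixFormat, String databaseId) {
        String name =
            String.format(prefixFormat, databaseId, UUID.randomUUID()).replaceAll("-", "_");
        return name.length() > MAX_NAME_LENGTH ? name.substring(0, MAX_NAME_LENGTH) : name;
      }

      public static void main(String[] args) {
        // Two runs yield distinct names, so a redeployed job no longer collides with
        // indexes left behind by a previous deployment against the same database.
        System.out.println(generate("WatermarkIdx_%s_%s", "my-database-id"));
        System.out.println(generate("WatermarkIdx_%s_%s", "my-database-id"));
      }
    }

Because every deployment draws a fresh UUID, dropping the old table and indexes (see the
updated deletePartitionMetadataTable below) can rely on the index names discovered from the
database rather than on fixed constants.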
--- .../beam/sdk/io/gcp/spanner/SpannerIO.java | 19 ++- .../spanner/changestreams/NameGenerator.java | 52 ------- .../spanner/changestreams/dao/DaoFactory.java | 12 +- .../dao/PartitionMetadataAdminDao.java | 58 +++---- .../dao/PartitionMetadataDao.java | 35 +++++ .../dao/PartitionMetadataTableNames.java | 144 ++++++++++++++++++ .../dofn/CleanUpReadChangeStreamDoFn.java | 4 +- .../changestreams/dofn/InitializeDoFn.java | 1 + .../changestreams/NameGeneratorTest.java | 41 ----- .../dao/PartitionMetadataAdminDaoTest.java | 56 +++++-- .../dao/PartitionMetadataTableNamesTest.java | 73 +++++++++ 11 files changed, 344 insertions(+), 151 deletions(-) delete mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java delete mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java create mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index 435bbba9ae8e..d9dde11a3081 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -25,7 +25,6 @@ import static org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.DEFAULT_RPC_PRIORITY; import static org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.MAX_INCLUSIVE_END_AT; import static org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.THROUGHPUT_WINDOW_SECONDS; -import static org.apache.beam.sdk.io.gcp.spanner.changestreams.NameGenerator.generatePartitionMetadataTableName; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; @@ -61,6 +60,7 @@ import java.util.HashMap; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -77,6 +77,7 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.MetadataSpannerConfigFactory; import org.apache.beam.sdk.io.gcp.spanner.changestreams.action.ActionFactory; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.DaoFactory; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.PartitionMetadataTableNames; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.CleanUpReadChangeStreamDoFn; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.DetectNewPartitionsDoFn; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.InitializeDoFn; @@ -1772,9 +1773,13 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta + fullPartitionMetadataDatabaseId + " has dialect " + metadataDatabaseDialect); - final String partitionMetadataTableName = - MoreObjects.firstNonNull( - getMetadataTable(), generatePartitionMetadataTableName(partitionMetadataDatabaseId)); + 
PartitionMetadataTableNames partitionMetadataTableNames = + Optional.ofNullable(getMetadataTable()) + .map( + table -> + PartitionMetadataTableNames.fromExistingTable( + partitionMetadataDatabaseId, table)) + .orElse(PartitionMetadataTableNames.generateRandom(partitionMetadataDatabaseId)); final String changeStreamName = getChangeStreamName(); final Timestamp startTimestamp = getInclusiveStartAt(); // Uses (Timestamp.MAX - 1ns) at max for end timestamp, because we add 1ns to transform the @@ -1791,7 +1796,7 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta changeStreamSpannerConfig, changeStreamName, partitionMetadataSpannerConfig, - partitionMetadataTableName, + partitionMetadataTableNames, rpcPriority, input.getPipeline().getOptions().getJobName(), changeStreamDatabaseDialect, @@ -1807,7 +1812,9 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta final PostProcessingMetricsDoFn postProcessingMetricsDoFn = new PostProcessingMetricsDoFn(metrics); - LOG.info("Partition metadata table that will be used is " + partitionMetadataTableName); + LOG.info( + "Partition metadata table that will be used is " + + partitionMetadataTableNames.getTableName()); final PCollection impulseOut = input.apply(Impulse.create()); final PCollection partitionsOut = diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java deleted file mode 100644 index 322e85cb07a2..000000000000 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.spanner.changestreams; - -import java.util.UUID; - -/** - * This class generates a unique name for the partition metadata table, which is created when the - * Connector is initialized. - */ -public class NameGenerator { - - private static final String PARTITION_METADATA_TABLE_NAME_FORMAT = "Metadata_%s_%s"; - private static final int MAX_TABLE_NAME_LENGTH = 63; - - /** - * Generates an unique name for the partition metadata table in the form of {@code - * "Metadata__"}. - * - * @param databaseId The database id where the table will be created - * @return the unique generated name of the partition metadata table - */ - public static String generatePartitionMetadataTableName(String databaseId) { - // There are 11 characters in the name format. - // Maximum Spanner database ID length is 30 characters. - // UUID always generates a String with 36 characters. 
- // Since the Postgres table name length is 63, we may need to truncate the table name depending - // on the database length. - String fullString = - String.format(PARTITION_METADATA_TABLE_NAME_FORMAT, databaseId, UUID.randomUUID()) - .replaceAll("-", "_"); - if (fullString.length() < MAX_TABLE_NAME_LENGTH) { - return fullString; - } - return fullString.substring(0, MAX_TABLE_NAME_LENGTH); - } -} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java index b9718fdb675e..787abad02e02 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java @@ -44,7 +44,7 @@ public class DaoFactory implements Serializable { private final SpannerConfig metadataSpannerConfig; private final String changeStreamName; - private final String partitionMetadataTableName; + private final PartitionMetadataTableNames partitionMetadataTableNames; private final RpcPriority rpcPriority; private final String jobName; private final Dialect spannerChangeStreamDatabaseDialect; @@ -56,7 +56,7 @@ public class DaoFactory implements Serializable { * @param changeStreamSpannerConfig the configuration for the change streams DAO * @param changeStreamName the name of the change stream for the change streams DAO * @param metadataSpannerConfig the metadata tables configuration - * @param partitionMetadataTableName the name of the created partition metadata table + * @param partitionMetadataTableNames the names of the partition metadata ddl objects * @param rpcPriority the priority of the requests made by the DAO queries * @param jobName the name of the running job */ @@ -64,7 +64,7 @@ public DaoFactory( SpannerConfig changeStreamSpannerConfig, String changeStreamName, SpannerConfig metadataSpannerConfig, - String partitionMetadataTableName, + PartitionMetadataTableNames partitionMetadataTableNames, RpcPriority rpcPriority, String jobName, Dialect spannerChangeStreamDatabaseDialect, @@ -78,7 +78,7 @@ public DaoFactory( this.changeStreamSpannerConfig = changeStreamSpannerConfig; this.changeStreamName = changeStreamName; this.metadataSpannerConfig = metadataSpannerConfig; - this.partitionMetadataTableName = partitionMetadataTableName; + this.partitionMetadataTableNames = partitionMetadataTableNames; this.rpcPriority = rpcPriority; this.jobName = jobName; this.spannerChangeStreamDatabaseDialect = spannerChangeStreamDatabaseDialect; @@ -102,7 +102,7 @@ public synchronized PartitionMetadataAdminDao getPartitionMetadataAdminDao() { databaseAdminClient, metadataSpannerConfig.getInstanceId().get(), metadataSpannerConfig.getDatabaseId().get(), - partitionMetadataTableName, + partitionMetadataTableNames, this.metadataDatabaseDialect); } return partitionMetadataAdminDao; @@ -120,7 +120,7 @@ public synchronized PartitionMetadataDao getPartitionMetadataDao() { if (partitionMetadataDaoInstance == null) { partitionMetadataDaoInstance = new PartitionMetadataDao( - this.partitionMetadataTableName, + this.partitionMetadataTableNames.getTableName(), spannerAccessor.getDatabaseClient(), this.metadataDatabaseDialect); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java index 368cab7022b3..3e6045d8858b 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java @@ -79,19 +79,13 @@ public class PartitionMetadataAdminDao { */ public static final String COLUMN_FINISHED_AT = "FinishedAt"; - /** Metadata table index for queries over the watermark column. */ - public static final String WATERMARK_INDEX = "WatermarkIndex"; - - /** Metadata table index for queries over the created at / start timestamp columns. */ - public static final String CREATED_AT_START_TIMESTAMP_INDEX = "CreatedAtStartTimestampIndex"; - private static final int TIMEOUT_MINUTES = 10; private static final int TTL_AFTER_PARTITION_FINISHED_DAYS = 1; private final DatabaseAdminClient databaseAdminClient; private final String instanceId; private final String databaseId; - private final String tableName; + private final PartitionMetadataTableNames names; private final Dialect dialect; /** @@ -101,18 +95,18 @@ public class PartitionMetadataAdminDao { * table * @param instanceId the instance where the metadata table will reside * @param databaseId the database where the metadata table will reside - * @param tableName the name of the metadata table + * @param names the names of the metadata table ddl objects */ PartitionMetadataAdminDao( DatabaseAdminClient databaseAdminClient, String instanceId, String databaseId, - String tableName, + PartitionMetadataTableNames names, Dialect dialect) { this.databaseAdminClient = databaseAdminClient; this.instanceId = instanceId; this.databaseId = databaseId; - this.tableName = tableName; + this.names = names; this.dialect = dialect; } @@ -128,8 +122,8 @@ public void createPartitionMetadataTable() { if (this.isPostgres()) { // Literals need be added around literals to preserve casing. 
ddl.add( - "CREATE TABLE \"" - + tableName + "CREATE TABLE IF NOT EXISTS \"" + + names.getTableName() + "\"(\"" + COLUMN_PARTITION_TOKEN + "\" text NOT NULL,\"" @@ -163,20 +157,20 @@ public void createPartitionMetadataTable() { + COLUMN_FINISHED_AT + "\""); ddl.add( - "CREATE INDEX \"" - + WATERMARK_INDEX + "CREATE INDEX IF NOT EXISTS \"" + + names.getWatermarkIndexName() + "\" on \"" - + tableName + + names.getTableName() + "\" (\"" + COLUMN_WATERMARK + "\") INCLUDE (\"" + COLUMN_STATE + "\")"); ddl.add( - "CREATE INDEX \"" - + CREATED_AT_START_TIMESTAMP_INDEX + "CREATE INDEX IF NOT EXISTS \"" + + names.getCreatedAtIndexName() + "\" ON \"" - + tableName + + names.getTableName() + "\" (\"" + COLUMN_CREATED_AT + "\",\"" @@ -184,8 +178,8 @@ public void createPartitionMetadataTable() { + "\")"); } else { ddl.add( - "CREATE TABLE " - + tableName + "CREATE TABLE IF NOT EXISTS " + + names.getTableName() + " (" + COLUMN_PARTITION_TOKEN + " STRING(MAX) NOT NULL," @@ -218,20 +212,20 @@ public void createPartitionMetadataTable() { + TTL_AFTER_PARTITION_FINISHED_DAYS + " DAY))"); ddl.add( - "CREATE INDEX " - + WATERMARK_INDEX + "CREATE INDEX IF NOT EXISTS " + + names.getWatermarkIndexName() + " on " - + tableName + + names.getTableName() + " (" + COLUMN_WATERMARK + ") STORING (" + COLUMN_STATE + ")"); ddl.add( - "CREATE INDEX " - + CREATED_AT_START_TIMESTAMP_INDEX + "CREATE INDEX IF NOT EXISTS " + + names.getCreatedAtIndexName() + " ON " - + tableName + + names.getTableName() + " (" + COLUMN_CREATED_AT + "," @@ -261,16 +255,14 @@ public void createPartitionMetadataTable() { * Drops the metadata table. This operation should complete in {@link * PartitionMetadataAdminDao#TIMEOUT_MINUTES} minutes. */ - public void deletePartitionMetadataTable() { + public void deletePartitionMetadataTable(List indexes) { List ddl = new ArrayList<>(); if (this.isPostgres()) { - ddl.add("DROP INDEX \"" + CREATED_AT_START_TIMESTAMP_INDEX + "\""); - ddl.add("DROP INDEX \"" + WATERMARK_INDEX + "\""); - ddl.add("DROP TABLE \"" + tableName + "\""); + indexes.forEach(index -> ddl.add("DROP INDEX \"" + index + "\"")); + ddl.add("DROP TABLE \"" + names.getTableName() + "\""); } else { - ddl.add("DROP INDEX " + CREATED_AT_START_TIMESTAMP_INDEX); - ddl.add("DROP INDEX " + WATERMARK_INDEX); - ddl.add("DROP TABLE " + tableName); + indexes.forEach(index -> ddl.add("DROP INDEX " + index)); + ddl.add("DROP TABLE " + names.getTableName()); } OperationFuture op = databaseAdminClient.updateDatabaseDdl(instanceId, databaseId, ddl, null); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java index 7867932cd1ad..654fd946663c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java @@ -96,6 +96,41 @@ public boolean tableExists() { } } + /** + * Finds all indexes for the metadata table. + * + * @return a list of index names for the metadata table. 
+ */ + public List findAllTableIndexes() { + String indexesStmt; + if (this.isPostgres()) { + indexesStmt = + "SELECT index_name FROM information_schema.indexes" + + " WHERE table_schema = 'public'" + + " AND table_name = '" + + metadataTableName + + "' AND index_type != 'PRIMARY_KEY'"; + } else { + indexesStmt = + "SELECT index_name FROM information_schema.indexes" + + " WHERE table_schema = ''" + + " AND table_name = '" + + metadataTableName + + "' AND index_type != 'PRIMARY_KEY'"; + } + + List result = new ArrayList<>(); + try (ResultSet queryResultSet = + databaseClient + .singleUseReadOnlyTransaction() + .executeQuery(Statement.of(indexesStmt), Options.tag("query=findAllTableIndexes"))) { + while (queryResultSet.next()) { + result.add(queryResultSet.getString("index_name")); + } + } + return result; + } + /** * Fetches the partition metadata row data for the given partition token. * diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java new file mode 100644 index 000000000000..07d7b80676de --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner.changestreams.dao; + +import java.io.Serializable; +import java.util.Objects; +import java.util.UUID; +import javax.annotation.Nullable; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; + +/** + * Configuration for a partition metadata table. It encapsulates the name of the metadata table and + * indexes. + */ +public class PartitionMetadataTableNames implements Serializable { + + private static final long serialVersionUID = 8848098877671834584L; + + /** PostgreSQL max table and index length is 63 bytes. */ + @VisibleForTesting static final int MAX_NAME_LENGTH = 63; + + private static final String PARTITION_METADATA_TABLE_NAME_FORMAT = "Metadata_%s_%s"; + private static final String WATERMARK_INDEX_NAME_FORMAT = "WatermarkIdx_%s_%s"; + private static final String CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT = "CreatedAtIdx_%s_%s"; + + /** + * Generates a unique name for the partition metadata table and its indexes. The table name will + * be in the form of {@code "Metadata__"}. The watermark index will be in the + * form of {@code "WatermarkIdx__}. The createdAt / start timestamp index will + * be in the form of {@code "CreatedAtIdx__}. 
+ * + * @param databaseId The database id where the table will be created + * @return the unique generated names of the partition metadata ddl + */ + public static PartitionMetadataTableNames generateRandom(String databaseId) { + UUID uuid = UUID.randomUUID(); + + String table = generateName(PARTITION_METADATA_TABLE_NAME_FORMAT, databaseId, uuid); + String watermarkIndex = generateName(WATERMARK_INDEX_NAME_FORMAT, databaseId, uuid); + String createdAtIndex = + generateName(CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT, databaseId, uuid); + + return new PartitionMetadataTableNames(table, watermarkIndex, createdAtIndex); + } + + /** + * Encapsulates a selected table name. Index names are generated, but will only be used if the + * given table does not exist. The watermark index will be in the form of {@code + * "WatermarkIdx__}. The createdAt / start timestamp index will be in the form + * of {@code "CreatedAtIdx__}. + * + * @param databaseId The database id for the table + * @param table The table name to be used + * @return an instance with the table name and generated index names + */ + public static PartitionMetadataTableNames fromExistingTable(String databaseId, String table) { + UUID uuid = UUID.randomUUID(); + + String watermarkIndex = generateName(WATERMARK_INDEX_NAME_FORMAT, databaseId, uuid); + String createdAtIndex = + generateName(CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT, databaseId, uuid); + return new PartitionMetadataTableNames(table, watermarkIndex, createdAtIndex); + } + + private static String generateName(String template, String databaseId, UUID uuid) { + String name = String.format(template, databaseId, uuid).replaceAll("-", "_"); + if (name.length() > MAX_NAME_LENGTH) { + return name.substring(0, MAX_NAME_LENGTH); + } + return name; + } + + private final String tableName; + private final String watermarkIndexName; + private final String createdAtIndexName; + + public PartitionMetadataTableNames( + String tableName, String watermarkIndexName, String createdAtIndexName) { + this.tableName = tableName; + this.watermarkIndexName = watermarkIndexName; + this.createdAtIndexName = createdAtIndexName; + } + + public String getTableName() { + return tableName; + } + + public String getWatermarkIndexName() { + return watermarkIndexName; + } + + public String getCreatedAtIndexName() { + return createdAtIndexName; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof PartitionMetadataTableNames)) { + return false; + } + PartitionMetadataTableNames that = (PartitionMetadataTableNames) o; + return Objects.equals(tableName, that.tableName) + && Objects.equals(watermarkIndexName, that.watermarkIndexName) + && Objects.equals(createdAtIndexName, that.createdAtIndexName); + } + + @Override + public int hashCode() { + return Objects.hash(tableName, watermarkIndexName, createdAtIndexName); + } + + @Override + public String toString() { + return "PartitionMetadataTableNames{" + + "tableName='" + + tableName + + '\'' + + ", watermarkIndexName='" + + watermarkIndexName + + '\'' + + ", createdAtIndexName='" + + createdAtIndexName + + '\'' + + '}'; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java index a048c885a001..f8aa497292bf 100644 --- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn; import java.io.Serializable; +import java.util.List; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.DaoFactory; import org.apache.beam.sdk.transforms.DoFn; @@ -33,6 +34,7 @@ public CleanUpReadChangeStreamDoFn(DaoFactory daoFactory) { @ProcessElement public void processElement(OutputReceiver receiver) { - daoFactory.getPartitionMetadataAdminDao().deletePartitionMetadataTable(); + List indexes = daoFactory.getPartitionMetadataDao().findAllTableIndexes(); + daoFactory.getPartitionMetadataAdminDao().deletePartitionMetadataTable(indexes); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java index 387ffd603b14..ca93f34bf1ba 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java @@ -64,6 +64,7 @@ public InitializeDoFn( public void processElement(OutputReceiver receiver) { PartitionMetadataDao partitionMetadataDao = daoFactory.getPartitionMetadataDao(); if (!partitionMetadataDao.tableExists()) { + // Creates partition metadata table and associated indexes daoFactory.getPartitionMetadataAdminDao().createPartitionMetadataTable(); createFakeParentPartition(); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java deleted file mode 100644 index f15fc5307374..000000000000 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.gcp.spanner.changestreams; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -import org.junit.Test; - -public class NameGeneratorTest { - private static final int MAXIMUM_POSTGRES_TABLE_NAME_LENGTH = 63; - - @Test - public void testGenerateMetadataTableNameRemovesHyphens() { - final String tableName = - NameGenerator.generatePartitionMetadataTableName("my-database-id-12345"); - assertFalse(tableName.contains("-")); - } - - @Test - public void testGenerateMetadataTableNameIsShorterThan64Characters() { - final String tableName = - NameGenerator.generatePartitionMetadataTableName("my-database-id1-maximum-length"); - assertTrue(tableName.length() <= MAXIMUM_POSTGRES_TABLE_NAME_LENGTH); - } -} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java index 3752c2fb3afc..02b9d111583b 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java @@ -33,7 +33,9 @@ import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SpannerException; import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; +import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.Iterator; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -58,6 +60,8 @@ public class PartitionMetadataAdminDaoTest { private static final String DATABASE_ID = "SPANNER_DATABASE"; private static final String TABLE_NAME = "SPANNER_TABLE"; + private static final String WATERMARK_INDEX_NAME = "WATERMARK_INDEX"; + private static final String CREATED_AT_INDEX_NAME = "CREATED_AT_INDEX"; private static final int TIMEOUT_MINUTES = 10; @@ -68,12 +72,14 @@ public class PartitionMetadataAdminDaoTest { @Before public void setUp() { databaseAdminClient = mock(DatabaseAdminClient.class); + PartitionMetadataTableNames names = + new PartitionMetadataTableNames(TABLE_NAME, WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME); partitionMetadataAdminDao = new PartitionMetadataAdminDao( - databaseAdminClient, INSTANCE_ID, DATABASE_ID, TABLE_NAME, Dialect.GOOGLE_STANDARD_SQL); + databaseAdminClient, INSTANCE_ID, DATABASE_ID, names, Dialect.GOOGLE_STANDARD_SQL); partitionMetadataAdminDaoPostgres = new PartitionMetadataAdminDao( - databaseAdminClient, INSTANCE_ID, DATABASE_ID, TABLE_NAME, Dialect.POSTGRESQL); + databaseAdminClient, INSTANCE_ID, DATABASE_ID, names, Dialect.POSTGRESQL); op = (OperationFuture) mock(OperationFuture.class); statements = ArgumentCaptor.forClass(Iterable.class); when(databaseAdminClient.updateDatabaseDdl( @@ -89,9 +95,9 @@ public void testCreatePartitionMetadataTable() throws Exception { .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); Iterator it = statements.getValue().iterator(); - assertTrue(it.next().contains("CREATE TABLE")); - assertTrue(it.next().contains("CREATE INDEX")); - assertTrue(it.next().contains("CREATE INDEX")); + assertTrue(it.next().contains("CREATE TABLE IF NOT EXISTS")); + 
assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS")); } @Test @@ -102,9 +108,9 @@ public void testCreatePartitionMetadataTablePostgres() throws Exception { .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); Iterator it = statements.getValue().iterator(); - assertTrue(it.next().contains("CREATE TABLE \"")); - assertTrue(it.next().contains("CREATE INDEX \"")); - assertTrue(it.next().contains("CREATE INDEX \"")); + assertTrue(it.next().contains("CREATE TABLE IF NOT EXISTS \"")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS \"")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS \"")); } @Test @@ -133,7 +139,8 @@ public void testCreatePartitionMetadataTableWithInterruptedException() throws Ex @Test public void testDeletePartitionMetadataTable() throws Exception { when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); - partitionMetadataAdminDao.deletePartitionMetadataTable(); + partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); verify(databaseAdminClient, times(1)) .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); @@ -143,10 +150,22 @@ public void testDeletePartitionMetadataTable() throws Exception { assertTrue(it.next().contains("DROP TABLE")); } + @Test + public void testDeletePartitionMetadataTableWithNoIndexes() throws Exception { + when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); + partitionMetadataAdminDao.deletePartitionMetadataTable(Collections.emptyList()); + verify(databaseAdminClient, times(1)) + .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); + assertEquals(1, ((Collection) statements.getValue()).size()); + Iterator it = statements.getValue().iterator(); + assertTrue(it.next().contains("DROP TABLE")); + } + @Test public void testDeletePartitionMetadataTablePostgres() throws Exception { when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); - partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable(); + partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); verify(databaseAdminClient, times(1)) .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); @@ -156,11 +175,23 @@ public void testDeletePartitionMetadataTablePostgres() throws Exception { assertTrue(it.next().contains("DROP TABLE \"")); } + @Test + public void testDeletePartitionMetadataTablePostgresWithNoIndexes() throws Exception { + when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); + partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable(Collections.emptyList()); + verify(databaseAdminClient, times(1)) + .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); + assertEquals(1, ((Collection) statements.getValue()).size()); + Iterator it = statements.getValue().iterator(); + assertTrue(it.next().contains("DROP TABLE \"")); + } + @Test public void testDeletePartitionMetadataTableWithTimeoutException() throws Exception { when(op.get(10, TimeUnit.MINUTES)).thenThrow(new TimeoutException(TIMED_OUT)); try { - partitionMetadataAdminDao.deletePartitionMetadataTable(); + 
partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); fail(); } catch (SpannerException e) { assertTrue(e.getMessage().contains(TIMED_OUT)); @@ -171,7 +202,8 @@ public void testDeletePartitionMetadataTableWithTimeoutException() throws Except public void testDeletePartitionMetadataTableWithInterruptedException() throws Exception { when(op.get(10, TimeUnit.MINUTES)).thenThrow(new InterruptedException(INTERRUPTED)); try { - partitionMetadataAdminDao.deletePartitionMetadataTable(); + partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); fail(); } catch (SpannerException e) { assertEquals(ErrorCode.CANCELLED, e.getErrorCode()); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java new file mode 100644 index 000000000000..2aae5b26a2cb --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.spanner.changestreams.dao; + +import static org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.PartitionMetadataTableNames.MAX_NAME_LENGTH; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +public class PartitionMetadataTableNamesTest { + @Test + public void testGeneratePartitionMetadataNamesRemovesHyphens() { + String databaseId = "my-database-id-12345"; + + PartitionMetadataTableNames names1 = PartitionMetadataTableNames.generateRandom(databaseId); + assertFalse(names1.getTableName().contains("-")); + assertFalse(names1.getWatermarkIndexName().contains("-")); + assertFalse(names1.getCreatedAtIndexName().contains("-")); + + PartitionMetadataTableNames names2 = PartitionMetadataTableNames.generateRandom(databaseId); + assertNotEquals(names1.getTableName(), names2.getTableName()); + assertNotEquals(names1.getWatermarkIndexName(), names2.getWatermarkIndexName()); + assertNotEquals(names1.getCreatedAtIndexName(), names2.getCreatedAtIndexName()); + } + + @Test + public void testGeneratePartitionMetadataNamesIsShorterThan64Characters() { + PartitionMetadataTableNames names = + PartitionMetadataTableNames.generateRandom( + "my-database-id-larger-than-maximum-length-1234567890-1234567890-1234567890"); + assertTrue(names.getTableName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getWatermarkIndexName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getCreatedAtIndexName().length() <= MAX_NAME_LENGTH); + + names = PartitionMetadataTableNames.generateRandom("d"); + assertTrue(names.getTableName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getWatermarkIndexName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getCreatedAtIndexName().length() <= MAX_NAME_LENGTH); + } + + @Test + public void testPartitionMetadataNamesFromExistingTable() { + PartitionMetadataTableNames names1 = + PartitionMetadataTableNames.fromExistingTable("databaseid", "mytable"); + assertEquals("mytable", names1.getTableName()); + assertFalse(names1.getWatermarkIndexName().contains("-")); + assertFalse(names1.getCreatedAtIndexName().contains("-")); + + PartitionMetadataTableNames names2 = + PartitionMetadataTableNames.fromExistingTable("databaseid", "mytable"); + assertEquals("mytable", names2.getTableName()); + assertNotEquals(names1.getWatermarkIndexName(), names2.getWatermarkIndexName()); + assertNotEquals(names1.getCreatedAtIndexName(), names2.getCreatedAtIndexName()); + } +} From 68f1543b6bfe4ceaa752c7f16fc2cae7393211fd Mon Sep 17 00:00:00 2001 From: martin trieu Date: Mon, 21 Oct 2024 03:05:43 -0600 Subject: [PATCH 76/82] Simplify budget distribution logic and new worker metadata consumption (#32775) --- .../FanOutStreamingEngineWorkerHarness.java | 379 ++++++++-------- .../harness/GlobalDataStreamSender.java | 63 +++ ...tate.java => StreamingEngineBackends.java} | 30 +- .../harness/WindmillStreamSender.java | 25 +- .../worker/windmill/WindmillEndpoints.java | 28 +- .../windmill/WindmillServiceAddress.java | 22 +- .../windmill/client/WindmillStream.java | 7 +- .../client/grpc/GrpcDirectGetWorkStream.java | 286 ++++++++----- .../client/grpc/GrpcGetDataStream.java | 2 +- .../client/grpc/GrpcGetWorkStream.java | 10 +- .../grpc/GrpcWindmillStreamFactory.java | 6 +- .../grpc/stubs/WindmillChannelFactory.java | 17 +- .../budget/EvenGetWorkBudgetDistributor.java | 59 +-- .../budget/GetWorkBudgetDistributors.java | 6 
+- .../work/budget/GetWorkBudgetSpender.java | 8 +- .../dataflow/worker/FakeWindmillServer.java | 10 +- ...anOutStreamingEngineWorkerHarnessTest.java | 111 ++--- .../harness/WindmillStreamSenderTest.java | 4 +- .../grpc/GrpcDirectGetWorkStreamTest.java | 405 ++++++++++++++++++ .../EvenGetWorkBudgetDistributorTest.java | 186 ++------ 20 files changed, 998 insertions(+), 666 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java rename runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/{StreamingEngineConnectionState.java => StreamingEngineBackends.java} (55%) create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 3556b7ce2919..458cf57ca8e7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -20,20 +20,25 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; -import java.util.Collection; -import java.util.List; +import java.io.Closeable; +import java.util.HashSet; import java.util.Map.Entry; +import java.util.NoSuchElementException; import java.util.Optional; -import java.util.Queue; -import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import java.util.stream.Collectors; import javax.annotation.CheckReturnValue; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; @@ -54,18 +59,14 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor; -import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.util.MoreFutures; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -80,32 +81,39 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorkerHarness { private static final Logger LOG = LoggerFactory.getLogger(FanOutStreamingEngineWorkerHarness.class); - private static final String PUBLISH_NEW_WORKER_METADATA_THREAD = "PublishNewWorkerMetadataThread"; - private static final String CONSUME_NEW_WORKER_METADATA_THREAD = "ConsumeNewWorkerMetadataThread"; + private static final String WORKER_METADATA_CONSUMER_THREAD_NAME = + "WindmillWorkerMetadataConsumerThread"; + private static final String STREAM_MANAGER_THREAD_NAME = "WindmillStreamManager-%d"; private final JobHeader jobHeader; private final GrpcWindmillStreamFactory streamFactory; private final WorkItemScheduler workItemScheduler; private final ChannelCachingStubFactory channelCachingStubFactory; private final GrpcDispatcherClient dispatcherClient; - private final AtomicBoolean isBudgetRefreshPaused; - private final GetWorkBudgetRefresher getWorkBudgetRefresher; - private final AtomicReference lastBudgetRefresh; + private final GetWorkBudgetDistributor getWorkBudgetDistributor; + private final GetWorkBudget totalGetWorkBudget; private final ThrottleTimer getWorkerMetadataThrottleTimer; - private final ExecutorService newWorkerMetadataPublisher; - private final ExecutorService newWorkerMetadataConsumer; - private final long clientId; - private final Supplier getWorkerMetadataStream; - private final Queue newWindmillEndpoints; private final Function workCommitterFactory; private final ThrottlingGetDataMetricTracker getDataMetricTracker; + private final ExecutorService windmillStreamManager; + private final ExecutorService workerMetadataConsumer; + private final Object metadataLock = new Object(); /** Writes are guarded by synchronization, reads are lock free. 
*/ - private final AtomicReference connections; + private final AtomicReference backends; - private volatile boolean started; + @GuardedBy("this") + private long activeMetadataVersion; + + @GuardedBy("metadataLock") + private long pendingMetadataVersion; + + @GuardedBy("this") + private boolean started; + + @GuardedBy("this") + private @Nullable GetWorkerMetadataStream getWorkerMetadataStream; - @SuppressWarnings("FutureReturnValueIgnored") private FanOutStreamingEngineWorkerHarness( JobHeader jobHeader, GetWorkBudget totalGetWorkBudget, @@ -114,7 +122,6 @@ private FanOutStreamingEngineWorkerHarness( ChannelCachingStubFactory channelCachingStubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId, Function workCommitterFactory, ThrottlingGetDataMetricTracker getDataMetricTracker) { this.jobHeader = jobHeader; @@ -122,42 +129,21 @@ private FanOutStreamingEngineWorkerHarness( this.started = false; this.streamFactory = streamFactory; this.workItemScheduler = workItemScheduler; - this.connections = new AtomicReference<>(StreamingEngineConnectionState.EMPTY); + this.backends = new AtomicReference<>(StreamingEngineBackends.EMPTY); this.channelCachingStubFactory = channelCachingStubFactory; this.dispatcherClient = dispatcherClient; - this.isBudgetRefreshPaused = new AtomicBoolean(false); this.getWorkerMetadataThrottleTimer = new ThrottleTimer(); - this.newWorkerMetadataPublisher = - singleThreadedExecutorServiceOf(PUBLISH_NEW_WORKER_METADATA_THREAD); - this.newWorkerMetadataConsumer = - singleThreadedExecutorServiceOf(CONSUME_NEW_WORKER_METADATA_THREAD); - this.clientId = clientId; - this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH); - this.newWindmillEndpoints = Queues.synchronizedQueue(EvictingQueue.create(1)); - this.getWorkBudgetRefresher = - new GetWorkBudgetRefresher( - isBudgetRefreshPaused::get, - () -> { - getWorkBudgetDistributor.distributeBudget( - connections.get().windmillStreams().values(), totalGetWorkBudget); - lastBudgetRefresh.set(Instant.now()); - }); - this.getWorkerMetadataStream = - Suppliers.memoize( - () -> - streamFactory.createGetWorkerMetadataStream( - dispatcherClient.getWindmillMetadataServiceStubBlocking(), - getWorkerMetadataThrottleTimer, - endpoints -> - // Run this on a separate thread than the grpc stream thread. 
- newWorkerMetadataPublisher.submit( - () -> newWindmillEndpoints.add(endpoints)))); + this.windmillStreamManager = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build()); + this.workerMetadataConsumer = + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder().setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME).build()); + this.getWorkBudgetDistributor = getWorkBudgetDistributor; + this.totalGetWorkBudget = totalGetWorkBudget; + this.activeMetadataVersion = Long.MIN_VALUE; this.workCommitterFactory = workCommitterFactory; - } - - private static ExecutorService singleThreadedExecutorServiceOf(String threadName) { - return Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder().setNameFormat(threadName).build()); + this.getWorkerMetadataStream = null; } /** @@ -183,7 +169,6 @@ public static FanOutStreamingEngineWorkerHarness create( channelCachingStubFactory, getWorkBudgetDistributor, dispatcherClient, - /* clientId= */ new Random().nextLong(), workCommitterFactory, getDataMetricTracker); } @@ -197,7 +182,6 @@ static FanOutStreamingEngineWorkerHarness forTesting( ChannelCachingStubFactory stubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId, Function workCommitterFactory, ThrottlingGetDataMetricTracker getDataMetricTracker) { FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkProvider = @@ -209,201 +193,218 @@ static FanOutStreamingEngineWorkerHarness forTesting( stubFactory, getWorkBudgetDistributor, dispatcherClient, - clientId, workCommitterFactory, getDataMetricTracker); fanOutStreamingEngineWorkProvider.start(); return fanOutStreamingEngineWorkProvider; } - @SuppressWarnings("ReturnValueIgnored") @Override public synchronized void start() { - Preconditions.checkState(!started, "StreamingEngineClient cannot start twice."); - // Starts the stream, this value is memoized. - getWorkerMetadataStream.get(); - startWorkerMetadataConsumer(); - getWorkBudgetRefresher.start(); + Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice."); + getWorkerMetadataStream = + streamFactory.createGetWorkerMetadataStream( + dispatcherClient.getWindmillMetadataServiceStubBlocking(), + getWorkerMetadataThrottleTimer, + this::consumeWorkerMetadata); started = true; } public ImmutableSet currentWindmillEndpoints() { - return connections.get().windmillConnections().keySet().stream() + return backends.get().windmillStreams().keySet().stream() .map(Endpoint::directEndpoint) .filter(Optional::isPresent) .map(Optional::get) - .filter( - windmillServiceAddress -> - windmillServiceAddress.getKind() != WindmillServiceAddress.Kind.IPV6) - .map( - windmillServiceAddress -> - windmillServiceAddress.getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS - ? windmillServiceAddress.gcpServiceAddress() - : windmillServiceAddress.authenticatedGcpServiceAddress().gcpServiceAddress()) + .map(WindmillServiceAddress::getServiceAddress) .collect(toImmutableSet()); } /** - * Fetches {@link GetDataStream} mapped to globalDataKey if one exists, or defaults to {@link - * GetDataStream} pointing to dispatcher. + * Fetches {@link GetDataStream} mapped to globalDataKey if or throws {@link + * NoSuchElementException} if one is not found. 
*/ private GetDataStream getGlobalDataStream(String globalDataKey) { - return Optional.ofNullable(connections.get().globalDataStreams().get(globalDataKey)) - .map(Supplier::get) - .orElseGet( - () -> - streamFactory.createGetDataStream( - dispatcherClient.getWindmillServiceStub(), new ThrottleTimer())); - } - - @SuppressWarnings("FutureReturnValueIgnored") - private void startWorkerMetadataConsumer() { - newWorkerMetadataConsumer.submit( - () -> { - while (true) { - Optional.ofNullable(newWindmillEndpoints.poll()) - .ifPresent(this::consumeWindmillWorkerEndpoints); - } - }); + return Optional.ofNullable(backends.get().globalDataStreams().get(globalDataKey)) + .map(GlobalDataStreamSender::get) + .orElseThrow( + () -> new NoSuchElementException("No endpoint for global data tag: " + globalDataKey)); } @VisibleForTesting @Override public synchronized void shutdown() { - Preconditions.checkState(started, "StreamingEngineClient never started."); - getWorkerMetadataStream.get().halfClose(); - getWorkBudgetRefresher.stop(); - newWorkerMetadataPublisher.shutdownNow(); - newWorkerMetadataConsumer.shutdownNow(); + Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness never started."); + Preconditions.checkNotNull(getWorkerMetadataStream).shutdown(); + workerMetadataConsumer.shutdownNow(); + closeStreamsNotIn(WindmillEndpoints.none()); channelCachingStubFactory.shutdown(); + + try { + Preconditions.checkNotNull(getWorkerMetadataStream).awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for GetWorkerMetadataStream to shutdown.", e); + } + + windmillStreamManager.shutdown(); + boolean isStreamManagerShutdown = false; + try { + isStreamManagerShutdown = windmillStreamManager.awaitTermination(30, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for windmillStreamManager to shutdown.", e); + } + if (!isStreamManagerShutdown) { + windmillStreamManager.shutdownNow(); + } + } + + private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) { + synchronized (metadataLock) { + // Only process versions greater than what we currently have to prevent double processing of + // metadata. workerMetadataConsumer is single-threaded so we maintain ordering. + if (windmillEndpoints.version() > pendingMetadataVersion) { + pendingMetadataVersion = windmillEndpoints.version(); + workerMetadataConsumer.execute(() -> consumeWindmillWorkerEndpoints(windmillEndpoints)); + } + } } - /** - * {@link java.util.function.Consumer} used to update {@link #connections} on - * new backend worker metadata. - */ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) { - isBudgetRefreshPaused.set(true); - LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints); - ImmutableMap newWindmillConnections = - createNewWindmillConnections(newWindmillEndpoints.windmillEndpoints()); - - StreamingEngineConnectionState newConnectionsState = - StreamingEngineConnectionState.builder() - .setWindmillConnections(newWindmillConnections) - .setWindmillStreams( - closeStaleStreamsAndCreateNewStreams(newWindmillConnections.values())) + // Since this is run on a single threaded executor, multiple versions of the metadata maybe + // queued up while a previous version of the windmillEndpoints were being consumed. Only consume + // the endpoints if they are the most current version. 
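+ // For example: if versions 5 and 6 arrive back to back, consumeWorkerMetadata() records
+ // pendingMetadataVersion = 6 and queues both tasks; when the queued task for version 5 runs,
+ // the check below sees 5 < 6 and returns early, so only the newest endpoints are applied.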
+ synchronized (metadataLock) { + if (newWindmillEndpoints.version() < pendingMetadataVersion) { + return; + } + } + + LOG.debug( + "Consuming new endpoints: {}. previous metadata version: {}, current metadata version: {}", + newWindmillEndpoints, + activeMetadataVersion, + newWindmillEndpoints.version()); + closeStreamsNotIn(newWindmillEndpoints); + ImmutableMap newStreams = + createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join(); + StreamingEngineBackends newBackends = + StreamingEngineBackends.builder() + .setWindmillStreams(newStreams) .setGlobalDataStreams( createNewGlobalDataStreams(newWindmillEndpoints.globalDataEndpoints())) .build(); + backends.set(newBackends); + getWorkBudgetDistributor.distributeBudget(newStreams.values(), totalGetWorkBudget); + activeMetadataVersion = newWindmillEndpoints.version(); + } + + /** Close the streams that are no longer valid asynchronously. */ + private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { + StreamingEngineBackends currentBackends = backends.get(); + currentBackends.windmillStreams().entrySet().stream() + .filter( + connectionAndStream -> + !newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey())) + .forEach( + entry -> + windmillStreamManager.execute( + () -> closeStreamSender(entry.getKey(), entry.getValue()))); - LOG.info( - "Setting new connections: {}. Previous connections: {}.", - newConnectionsState, - connections.get()); - connections.set(newConnectionsState); - isBudgetRefreshPaused.set(false); - getWorkBudgetRefresher.requestBudgetRefresh(); + Set newGlobalDataEndpoints = + new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values()); + currentBackends.globalDataStreams().values().stream() + .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) + .forEach( + sender -> + windmillStreamManager.execute(() -> closeStreamSender(sender.endpoint(), sender))); + } + + private void closeStreamSender(Endpoint endpoint, Closeable sender) { + LOG.debug("Closing streams to endpoint={}, sender={}", endpoint, sender); + try { + sender.close(); + endpoint.directEndpoint().ifPresent(channelCachingStubFactory::remove); + LOG.debug("Successfully closed streams to {}", endpoint); + } catch (Exception e) { + LOG.error("Error closing streams to endpoint={}, sender={}", endpoint, sender); + } + } + + private synchronized CompletableFuture> + createAndStartNewStreams(ImmutableSet newWindmillEndpoints) { + ImmutableMap currentStreams = backends.get().windmillStreams(); + return MoreFutures.allAsList( + newWindmillEndpoints.stream() + .map(endpoint -> getOrCreateWindmillStreamSenderFuture(endpoint, currentStreams)) + .collect(Collectors.toList())) + .thenApply( + backends -> backends.stream().collect(toImmutableMap(Pair::getLeft, Pair::getRight))) + .toCompletableFuture(); + } + + private CompletionStage> + getOrCreateWindmillStreamSenderFuture( + Endpoint endpoint, ImmutableMap currentStreams) { + return MoreFutures.supplyAsync( + () -> + Pair.of( + endpoint, + Optional.ofNullable(currentStreams.get(endpoint)) + .orElseGet(() -> createAndStartWindmillStreamSender(endpoint))), + windmillStreamManager); } /** Add up all the throttle times of all streams including GetWorkerMetadataStream. 
*/ - public long getAndResetThrottleTimes() { - return connections.get().windmillStreams().values().stream() + public long getAndResetThrottleTime() { + return backends.get().windmillStreams().values().stream() .map(WindmillStreamSender::getAndResetThrottleTime) .reduce(0L, Long::sum) + getWorkerMetadataThrottleTimer.getAndResetThrottleTime(); } public long currentActiveCommitBytes() { - return connections.get().windmillStreams().values().stream() + return backends.get().windmillStreams().values().stream() .map(WindmillStreamSender::getCurrentActiveCommitBytes) .reduce(0L, Long::sum); } @VisibleForTesting - StreamingEngineConnectionState getCurrentConnections() { - return connections.get(); - } - - private synchronized ImmutableMap createNewWindmillConnections( - List newWindmillEndpoints) { - ImmutableMap currentConnections = - connections.get().windmillConnections(); - return newWindmillEndpoints.stream() - .collect( - toImmutableMap( - Function.identity(), - endpoint -> - // Reuse existing stubs if they exist. Optional.orElseGet only calls the - // supplier if the value is not present, preventing constructing expensive - // objects. - Optional.ofNullable(currentConnections.get(endpoint)) - .orElseGet( - () -> WindmillConnection.from(endpoint, this::createWindmillStub)))); + StreamingEngineBackends currentBackends() { + return backends.get(); } - private synchronized ImmutableMap - closeStaleStreamsAndCreateNewStreams(Collection newWindmillConnections) { - ImmutableMap currentStreams = - connections.get().windmillStreams(); - - // Close the streams that are no longer valid. - currentStreams.entrySet().stream() - .filter( - connectionAndStream -> !newWindmillConnections.contains(connectionAndStream.getKey())) - .forEach( - entry -> { - entry.getValue().closeAllStreams(); - entry.getKey().directEndpoint().ifPresent(channelCachingStubFactory::remove); - }); - - return newWindmillConnections.stream() - .collect( - toImmutableMap( - Function.identity(), - newConnection -> - Optional.ofNullable(currentStreams.get(newConnection)) - .orElseGet(() -> createAndStartWindmillStreamSenderFor(newConnection)))); - } - - private ImmutableMap> createNewGlobalDataStreams( + private ImmutableMap createNewGlobalDataStreams( ImmutableMap newGlobalDataEndpoints) { - ImmutableMap> currentGlobalDataStreams = - connections.get().globalDataStreams(); + ImmutableMap currentGlobalDataStreams = + backends.get().globalDataStreams(); return newGlobalDataEndpoints.entrySet().stream() .collect( toImmutableMap( Entry::getKey, keyedEndpoint -> - existingOrNewGetDataStreamFor(keyedEndpoint, currentGlobalDataStreams))); + getOrCreateGlobalDataSteam(keyedEndpoint, currentGlobalDataStreams))); } - private Supplier existingOrNewGetDataStreamFor( + private GlobalDataStreamSender getOrCreateGlobalDataSteam( Entry keyedEndpoint, - ImmutableMap> currentGlobalDataStreams) { - return Preconditions.checkNotNull( - currentGlobalDataStreams.getOrDefault( - keyedEndpoint.getKey(), + ImmutableMap currentGlobalDataStreams) { + return Optional.ofNullable(currentGlobalDataStreams.get(keyedEndpoint.getKey())) + .orElseGet( () -> - streamFactory.createGetDataStream( - newOrExistingStubFor(keyedEndpoint.getValue()), new ThrottleTimer()))); - } - - private CloudWindmillServiceV1Alpha1Stub newOrExistingStubFor(Endpoint endpoint) { - return Optional.ofNullable(connections.get().windmillConnections().get(endpoint)) - .map(WindmillConnection::stub) - .orElseGet(() -> createWindmillStub(endpoint)); + new GlobalDataStreamSender( + () -> + 
streamFactory.createGetDataStream( + createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer()), + keyedEndpoint.getValue())); } - private WindmillStreamSender createAndStartWindmillStreamSenderFor( - WindmillConnection connection) { - // Initially create each stream with no budget. The budget will be eventually assigned by the - // GetWorkBudgetDistributor. + private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoint) { WindmillStreamSender windmillStreamSender = WindmillStreamSender.create( - connection, + WindmillConnection.from(endpoint, this::createWindmillStub), GetWorkRequest.newBuilder() - .setClientId(clientId) + .setClientId(jobHeader.getClientId()) .setJobId(jobHeader.getJobId()) .setProjectId(jobHeader.getProjectId()) .setWorkerId(jobHeader.getWorkerId()) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java new file mode 100644 index 000000000000..ce5f3a7b6bfc --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.streaming.harness; + +import java.io.Closeable; +import java.util.function.Supplier; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; + +@Internal +@ThreadSafe +// TODO (m-trieu): replace Supplier with Stream after github.com/apache/beam/pull/32774/ is +// merged +final class GlobalDataStreamSender implements Closeable, Supplier { + private final Endpoint endpoint; + private final Supplier delegate; + private volatile boolean started; + + GlobalDataStreamSender(Supplier delegate, Endpoint endpoint) { + // Ensures that the Supplier is thread-safe + this.delegate = Suppliers.memoize(delegate::get); + this.started = false; + this.endpoint = endpoint; + } + + @Override + public GetDataStream get() { + if (!started) { + started = true; + } + + return delegate.get(); + } + + @Override + public void close() { + if (started) { + delegate.get().shutdown(); + } + } + + Endpoint endpoint() { + return endpoint; + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java similarity index 55% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java index 3c85ee6abe1f..14290b486830 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java @@ -18,47 +18,37 @@ package org.apache.beam.runners.dataflow.worker.streaming.harness; import com.google.auto.value.AutoValue; -import java.util.function.Supplier; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; /** - * Represents the current state of connections to Streaming Engine. Connections are updated when - * backend workers assigned to the key ranges being processed by this user worker change during + * Represents the current state of connections to the Streaming Engine backend. Backends are updated + * when backend workers assigned to the key ranges being processed by this user worker change during * pipeline execution. For example, changes can happen via autoscaling, load-balancing, or other * backend updates. 
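+ *
+ * <p>Keys of {@code windmillStreams} are the direct backend {@link Endpoint}s; each value owns
+ * the set of streams used to communicate with that backend worker.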
*/ @AutoValue -abstract class StreamingEngineConnectionState { - static final StreamingEngineConnectionState EMPTY = builder().build(); +abstract class StreamingEngineBackends { + static final StreamingEngineBackends EMPTY = builder().build(); static Builder builder() { - return new AutoValue_StreamingEngineConnectionState.Builder() - .setWindmillConnections(ImmutableMap.of()) + return new AutoValue_StreamingEngineBackends.Builder() .setWindmillStreams(ImmutableMap.of()) .setGlobalDataStreams(ImmutableMap.of()); } - abstract ImmutableMap windmillConnections(); - - abstract ImmutableMap windmillStreams(); + abstract ImmutableMap windmillStreams(); /** Mapping of GlobalDataIds and the direct GetDataStreams used fetch them. */ - abstract ImmutableMap> globalDataStreams(); + abstract ImmutableMap globalDataStreams(); @AutoValue.Builder abstract static class Builder { - public abstract Builder setWindmillConnections( - ImmutableMap value); - - public abstract Builder setWindmillStreams( - ImmutableMap value); + public abstract Builder setWindmillStreams(ImmutableMap value); public abstract Builder setGlobalDataStreams( - ImmutableMap> value); + ImmutableMap value); - public abstract StreamingEngineConnectionState build(); + public abstract StreamingEngineBackends build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java index 45aa403ee71b..744c3d74445f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import java.io.Closeable; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -49,7 +50,7 @@ * {@link GetWorkBudget} is set. * *

Once started, the underlying streams are "alive" until they are manually closed via {@link - * #closeAllStreams()}. + * #close()}. * *

If closed, it means that the backend endpoint is no longer in the worker set. Once closed, * these instances are not reused. @@ -59,7 +60,7 @@ */ @Internal @ThreadSafe -final class WindmillStreamSender implements GetWorkBudgetSpender { +final class WindmillStreamSender implements GetWorkBudgetSpender, Closeable { private final AtomicBoolean started; private final AtomicReference getWorkBudget; private final Supplier getWorkStream; @@ -103,9 +104,9 @@ private WindmillStreamSender( connection, withRequestBudget(getWorkRequest, getWorkBudget.get()), streamingEngineThrottleTimers.getWorkThrottleTimer(), - () -> FixedStreamHeartbeatSender.create(getDataStream.get()), - () -> getDataClientFactory.apply(getDataStream.get()), - workCommitter, + FixedStreamHeartbeatSender.create(getDataStream.get()), + getDataClientFactory.apply(getDataStream.get()), + workCommitter.get(), workItemScheduler)); } @@ -141,7 +142,8 @@ void startStreams() { started.set(true); } - void closeAllStreams() { + @Override + public void close() { // Supplier.get() starts the stream which is an expensive operation as it initiates the // streaming RPCs by possibly making calls over the network. Do not close the streams unless // they have already been started. @@ -154,18 +156,13 @@ void closeAllStreams() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { - getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta)); + public void setBudget(long items, long bytes) { + getWorkBudget.set(getWorkBudget.get().apply(items, bytes)); if (started.get()) { - getWorkStream.get().adjustBudget(itemsDelta, bytesDelta); + getWorkStream.get().setBudget(items, bytes); } } - @Override - public GetWorkBudget remainingBudget() { - return started.get() ? getWorkStream.get().remainingBudget() : getWorkBudget.get(); - } - long getAndResetThrottleTime() { return streamingEngineThrottleTimers.getAndResetThrottleTime(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java index d7ed83def43e..eb269eef848f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -17,8 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; import com.google.auto.value.AutoValue; import java.net.Inet6Address; @@ -27,8 +27,8 @@ import java.util.Map; import java.util.Optional; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; @@ -41,6 +41,14 @@ public abstract class WindmillEndpoints { private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class); + public static WindmillEndpoints none() { + return WindmillEndpoints.builder() + .setVersion(Long.MAX_VALUE) + .setWindmillEndpoints(ImmutableSet.of()) + .setGlobalDataEndpoints(ImmutableMap.of()) + .build(); + } + public static WindmillEndpoints from( Windmill.WorkerMetadataResponse workerMetadataResponseProto) { ImmutableMap globalDataServers = @@ -53,14 +61,15 @@ public static WindmillEndpoints from( endpoint.getValue(), workerMetadataResponseProto.getExternalEndpoint()))); - ImmutableList windmillServers = + ImmutableSet windmillServers = workerMetadataResponseProto.getWorkEndpointsList().stream() .map( endpointProto -> Endpoint.from(endpointProto, workerMetadataResponseProto.getExternalEndpoint())) - .collect(toImmutableList()); + .collect(toImmutableSet()); return WindmillEndpoints.builder() + .setVersion(workerMetadataResponseProto.getMetadataVersion()) .setGlobalDataEndpoints(globalDataServers) .setWindmillEndpoints(windmillServers) .build(); @@ -123,6 +132,9 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( directEndpointAddress.getHostAddress(), (int) endpointProto.getPort())); } + /** Version of the endpoints which increases with every modification. */ + public abstract long version(); + /** * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key * is a global data tag and the value is the endpoint where the data associated with the global @@ -138,7 +150,7 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( * Windmill servers. Returns a list of endpoints used to communicate with the corresponding * Windmill servers. */ - public abstract ImmutableList windmillEndpoints(); + public abstract ImmutableSet windmillEndpoints(); /** * Representation of an endpoint in {@link Windmill.WorkerMetadataResponse.Endpoint} proto with @@ -204,13 +216,15 @@ public abstract static class Builder { @AutoValue.Builder public abstract static class Builder { + public abstract Builder setVersion(long version); + public abstract Builder setGlobalDataEndpoints( ImmutableMap globalDataServers); public abstract Builder setWindmillEndpoints( - ImmutableList windmillServers); + ImmutableSet windmillServers); - abstract ImmutableList.Builder windmillEndpointsBuilder(); + abstract ImmutableSet.Builder windmillEndpointsBuilder(); public final Builder addWindmillEndpoint(WindmillEndpoints.Endpoint endpoint) { windmillEndpointsBuilder().add(endpoint); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java index 90f93b072673..0b895652efe2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java @@ -19,38 +19,36 @@ import com.google.auto.value.AutoOneOf; import com.google.auto.value.AutoValue; -import java.net.Inet6Address; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; /** Used to create channels to communicate with Streaming Engine via gRpc. 
*/ @AutoOneOf(WindmillServiceAddress.Kind.class) public abstract class WindmillServiceAddress { - public static WindmillServiceAddress create(Inet6Address ipv6Address) { - return AutoOneOf_WindmillServiceAddress.ipv6(ipv6Address); - } public static WindmillServiceAddress create(HostAndPort gcpServiceAddress) { return AutoOneOf_WindmillServiceAddress.gcpServiceAddress(gcpServiceAddress); } - public abstract Kind getKind(); + public static WindmillServiceAddress create( + AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) { + return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress( + authenticatedGcpServiceAddress); + } - public abstract Inet6Address ipv6(); + public abstract Kind getKind(); public abstract HostAndPort gcpServiceAddress(); public abstract AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress(); - public static WindmillServiceAddress create( - AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) { - return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress( - authenticatedGcpServiceAddress); + public final HostAndPort getServiceAddress() { + return getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS + ? gcpServiceAddress() + : authenticatedGcpServiceAddress().gcpServiceAddress(); } public enum Kind { - IPV6, GCP_SERVICE_ADDRESS, - // TODO(m-trieu): Use for direct connections when ALTS is enabled. AUTHENTICATED_GCP_SERVICE_ADDRESS } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index 31bd4e146a78..f26c56b14ec2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -56,10 +56,11 @@ public interface WindmillStream { @ThreadSafe interface GetWorkStream extends WindmillStream { /** Adjusts the {@link GetWorkBudget} for the stream. */ - void adjustBudget(long itemsDelta, long bytesDelta); + void setBudget(GetWorkBudget newBudget); - /** Returns the remaining in-flight {@link GetWorkBudget}. */ - GetWorkBudget remainingBudget(); + default void setBudget(long newItems, long newBytes) { + setBudget(GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build()); + } } /** Interface for streaming GetDataRequests to Windmill. 
*/ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 19de998b1da8..b27ebc8e9eee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -21,9 +21,11 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import javax.annotation.concurrent.GuardedBy; +import net.jcip.annotations.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -44,8 +46,8 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Implementation of {@link GetWorkStream} that passes along a specific {@link @@ -55,9 +57,10 @@ * these direct streams are used to facilitate these RPC calls to specific backend workers. */ @Internal -public final class GrpcDirectGetWorkStream +final class GrpcDirectGetWorkStream extends AbstractWindmillStream implements GetWorkStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcDirectGetWorkStream.class); private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST = StreamingGetWorkRequest.newBuilder() .setRequestExtension( @@ -67,15 +70,14 @@ public final class GrpcDirectGetWorkStream .build()) .build(); - private final AtomicReference inFlightBudget; - private final AtomicReference nextBudgetAdjustment; - private final AtomicReference pendingResponseBudget; - private final GetWorkRequest request; + private final GetWorkBudgetTracker budgetTracker; + private final GetWorkRequest requestHeader; private final WorkItemScheduler workItemScheduler; private final ThrottleTimer getWorkThrottleTimer; - private final Supplier heartbeatSender; - private final Supplier workCommitter; - private final Supplier getDataClient; + private final HeartbeatSender heartbeatSender; + private final WorkCommitter workCommitter; + private final GetDataClient getDataClient; + private final AtomicReference lastRequest; /** * Map of stream IDs to their buffers. 
Used to aggregate streaming gRPC response chunks as they @@ -92,15 +94,15 @@ private GrpcDirectGetWorkStream( StreamObserver, StreamObserver> startGetWorkRpcFn, - GetWorkRequest request, + GetWorkRequest requestHeader, BackOff backoff, StreamObserverFactory streamObserverFactory, Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { super( "GetWorkStream", @@ -110,19 +112,23 @@ private GrpcDirectGetWorkStream( streamRegistry, logEveryNStreamFailures, backendWorkerToken); - this.request = request; + this.requestHeader = requestHeader; this.getWorkThrottleTimer = getWorkThrottleTimer; this.workItemScheduler = workItemScheduler; this.workItemAssemblers = new ConcurrentHashMap<>(); - this.heartbeatSender = Suppliers.memoize(heartbeatSender::get); - this.workCommitter = Suppliers.memoize(workCommitter::get); - this.getDataClient = Suppliers.memoize(getDataClient::get); - this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget()); - this.nextBudgetAdjustment = new AtomicReference<>(GetWorkBudget.noBudget()); - this.pendingResponseBudget = new AtomicReference<>(GetWorkBudget.noBudget()); + this.heartbeatSender = heartbeatSender; + this.workCommitter = workCommitter; + this.getDataClient = getDataClient; + this.lastRequest = new AtomicReference<>(); + this.budgetTracker = + new GetWorkBudgetTracker( + GetWorkBudget.builder() + .setItems(requestHeader.getMaxItems()) + .setBytes(requestHeader.getMaxBytes()) + .build()); } - public static GrpcDirectGetWorkStream create( + static GrpcDirectGetWorkStream create( String backendWorkerToken, Function< StreamObserver, @@ -134,9 +140,9 @@ public static GrpcDirectGetWorkStream create( Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { GrpcDirectGetWorkStream getWorkStream = new GrpcDirectGetWorkStream( @@ -165,46 +171,52 @@ private static Watermarks createWatermarks( .build(); } - private void sendRequestExtension(GetWorkBudget adjustment) { - inFlightBudget.getAndUpdate(budget -> budget.apply(adjustment)); - StreamingGetWorkRequest extension = - StreamingGetWorkRequest.newBuilder() - .setRequestExtension( - Windmill.StreamingGetWorkRequestExtension.newBuilder() - .setMaxItems(adjustment.items()) - .setMaxBytes(adjustment.bytes())) - .build(); - - executor() - .execute( - () -> { - try { - send(extension); - } catch (IllegalStateException e) { - // Stream was closed. - } - }); + /** + * @implNote Do not lock/synchronize here due to this running on grpc serial executor for message + * which can deadlock since we send on the stream beneath the synchronization. {@link + * AbstractWindmillStream#send(Object)} is synchronized so the sends are already guarded. 
+ */ + private void maybeSendRequestExtension(GetWorkBudget extension) { + if (extension.items() > 0 || extension.bytes() > 0) { + executeSafely( + () -> { + StreamingGetWorkRequest request = + StreamingGetWorkRequest.newBuilder() + .setRequestExtension( + Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(extension.items()) + .setMaxBytes(extension.bytes())) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(extension); + try { + send(request); + } catch (IllegalStateException e) { + // Stream was closed. + } + }); + } } @Override protected synchronized void onNewStream() { workItemAssemblers.clear(); - // Add the current in-flight budget to the next adjustment. Only positive values are allowed - // here - // with negatives defaulting to 0, since GetWorkBudgets cannot be created with negative values. - GetWorkBudget budgetAdjustment = nextBudgetAdjustment.get().apply(inFlightBudget.get()); - inFlightBudget.set(budgetAdjustment); - send( - StreamingGetWorkRequest.newBuilder() - .setRequest( - request - .toBuilder() - .setMaxBytes(budgetAdjustment.bytes()) - .setMaxItems(budgetAdjustment.items())) - .build()); - - // We just sent the budget, reset it. - nextBudgetAdjustment.set(GetWorkBudget.noBudget()); + if (!isShutdown()) { + budgetTracker.reset(); + GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); + StreamingGetWorkRequest request = + StreamingGetWorkRequest.newBuilder() + .setRequest( + requestHeader + .toBuilder() + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build()) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(initialGetWorkBudget); + send(request); + } } @Override @@ -216,8 +228,9 @@ protected boolean hasPendingRequests() { public void appendSpecificHtml(PrintWriter writer) { // Number of buffers is same as distinct workers that sent work on this stream. writer.format( - "GetWorkStream: %d buffers, %s inflight budget allowed.", - workItemAssemblers.size(), inFlightBudget.get()); + "GetWorkStream: %d buffers, " + "last sent request: %s; ", + workItemAssemblers.size(), lastRequest.get()); + writer.print(budgetTracker.debugString()); } @Override @@ -235,30 +248,22 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) { } private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) { - // Record the fact that there are now fewer outstanding messages and bytes on the stream. 
- inFlightBudget.updateAndGet(budget -> budget.subtract(1, assembledWorkItem.bufferedSize())); WorkItem workItem = assembledWorkItem.workItem(); GetWorkResponseChunkAssembler.ComputationMetadata metadata = assembledWorkItem.computationMetadata(); - pendingResponseBudget.getAndUpdate(budget -> budget.apply(1, workItem.getSerializedSize())); - try { - workItemScheduler.scheduleWork( - workItem, - createWatermarks(workItem, Preconditions.checkNotNull(metadata)), - createProcessingContext(Preconditions.checkNotNull(metadata.computationId())), - assembledWorkItem.latencyAttributions()); - } finally { - pendingResponseBudget.getAndUpdate(budget -> budget.apply(-1, -workItem.getSerializedSize())); - } + workItemScheduler.scheduleWork( + workItem, + createWatermarks(workItem, metadata), + createProcessingContext(metadata.computationId()), + assembledWorkItem.latencyAttributions()); + budgetTracker.recordBudgetReceived(assembledWorkItem.bufferedSize()); + GetWorkBudget extension = budgetTracker.computeBudgetExtension(); + maybeSendRequestExtension(extension); } private Work.ProcessingContext createProcessingContext(String computationId) { return Work.createProcessingContext( - computationId, - getDataClient.get(), - workCommitter.get()::commit, - heartbeatSender.get(), - backendWorkerToken()); + computationId, getDataClient, workCommitter::commit, heartbeatSender, backendWorkerToken()); } @Override @@ -267,25 +272,110 @@ protected void startThrottleTimer() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { - GetWorkBudget adjustment = - nextBudgetAdjustment - // Get the current value, and reset the nextBudgetAdjustment. This will be set again - // when adjustBudget is called. - .getAndUpdate(unused -> GetWorkBudget.noBudget()) - .apply(itemsDelta, bytesDelta); - sendRequestExtension(adjustment); + public void setBudget(GetWorkBudget newBudget) { + GetWorkBudget extension = budgetTracker.consumeAndComputeBudgetUpdate(newBudget); + maybeSendRequestExtension(extension); } - @Override - public GetWorkBudget remainingBudget() { - // Snapshot the current budgets. - GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get(); - GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get(); - GetWorkBudget currentInflightBudget = inFlightBudget.get(); - - return currentPendingResponseBudget - .apply(currentNextBudgetAdjustment) - .apply(currentInflightBudget); + private void executeSafely(Runnable runnable) { + try { + executor().execute(runnable); + } catch (RejectedExecutionException e) { + LOG.debug("{} has been shutdown.", getClass()); + } + } + + /** + * Tracks sent, received, max {@link GetWorkBudget} and uses this information to generate request + * extensions. 
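+ *
+ * <p>For example, with a maximum budget of 100 items / 1000 bytes, once receipts leave only 40
+ * items / 300 bytes in flight, the computed extension is 60 items / 700 bytes; because the
+ * in-flight bytes (300) are no longer above half of the corresponding request (700 / 2 = 350),
+ * the extension is sent and the stream is topped back up instead of returning an empty budget.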
+ */ + @ThreadSafe + private static final class GetWorkBudgetTracker { + + @GuardedBy("GetWorkBudgetTracker.this") + private GetWorkBudget maxGetWorkBudget; + + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsRequested = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesRequested = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsReceived = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesReceived = 0; + + private GetWorkBudgetTracker(GetWorkBudget maxGetWorkBudget) { + this.maxGetWorkBudget = maxGetWorkBudget; + } + + private synchronized void reset() { + itemsRequested = 0; + bytesRequested = 0; + itemsReceived = 0; + bytesReceived = 0; + } + + private synchronized String debugString() { + return String.format( + "max budget: %s; " + + "in-flight budget: %s; " + + "total budget requested: %s; " + + "total budget received: %s.", + maxGetWorkBudget, inFlightBudget(), totalRequestedBudget(), totalReceivedBudget()); + } + + /** Consumes the new budget and computes an extension based on the new budget. */ + private synchronized GetWorkBudget consumeAndComputeBudgetUpdate(GetWorkBudget newBudget) { + maxGetWorkBudget = newBudget; + return computeBudgetExtension(); + } + + private synchronized void recordBudgetRequested(GetWorkBudget budgetRequested) { + itemsRequested += budgetRequested.items(); + bytesRequested += budgetRequested.bytes(); + } + + private synchronized void recordBudgetReceived(long returnedBudget) { + itemsReceived++; + bytesReceived += returnedBudget; + } + + /** + * If the outstanding items or bytes limit has gotten too low, top both off with a + * GetWorkExtension. The goal is to keep the limits relatively close to their maximum values + * without sending too many extension requests. + */ + private synchronized GetWorkBudget computeBudgetExtension() { + // Expected items and bytes can go negative here, since WorkItems returned might be larger + // than the initially requested budget. + long inFlightItems = itemsRequested - itemsReceived; + long inFlightBytes = bytesRequested - bytesReceived; + + // Don't send negative budget extensions. + long requestBytes = Math.max(0, maxGetWorkBudget.bytes() - inFlightBytes); + long requestItems = Math.max(0, maxGetWorkBudget.items() - inFlightItems); + + return (inFlightItems > requestItems / 2 && inFlightBytes > requestBytes / 2) + ? 
GetWorkBudget.noBudget() + : GetWorkBudget.builder().setItems(requestItems).setBytes(requestBytes).build(); + } + + private synchronized GetWorkBudget inFlightBudget() { + return GetWorkBudget.builder() + .setItems(itemsRequested - itemsReceived) + .setBytes(bytesRequested - bytesReceived) + .build(); + } + + private synchronized GetWorkBudget totalRequestedBudget() { + return GetWorkBudget.builder().setItems(itemsRequested).setBytes(bytesRequested).build(); + } + + private synchronized GetWorkBudget totalReceivedBudget() { + return GetWorkBudget.builder().setItems(itemsReceived).setBytes(bytesReceived).build(); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 0e9a0c6316ee..c99e05a77074 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -59,7 +59,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class GrpcGetDataStream +final class GrpcGetDataStream extends AbstractWindmillStream implements GetDataStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index 09ecbf3f3051..a368f3fec235 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -194,15 +194,7 @@ protected void startThrottleTimer() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { + public void setBudget(GetWorkBudget newBudget) { // no-op } - - @Override - public GetWorkBudget remainingBudget() { - return GetWorkBudget.builder() - .setBytes(request.getMaxBytes() - inflightBytes.get()) - .setItems(request.getMaxItems() - inflightMessages.get()) - .build(); - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 92f031db9972..9e6a02d135e2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -198,9 +198,9 @@ public GetWorkStream createDirectGetWorkStream( WindmillConnection connection, GetWorkRequest request, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + 
WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { return GrpcDirectGetWorkStream.create( connection.backendWorkerToken(), diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java index 9aec29a3ba4d..f0ea2f550a74 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java @@ -36,7 +36,6 @@ /** Utility class used to create different RPC Channels. */ public final class WindmillChannelFactory { public static final String LOCALHOST = "localhost"; - private static final int DEFAULT_GRPC_PORT = 443; private static final int MAX_REMOTE_TRACE_EVENTS = 100; private WindmillChannelFactory() {} @@ -55,8 +54,6 @@ public static Channel localhostChannel(int port) { public static ManagedChannel remoteChannel( WindmillServiceAddress windmillServiceAddress, int windmillServiceRpcChannelTimeoutSec) { switch (windmillServiceAddress.getKind()) { - case IPV6: - return remoteChannel(windmillServiceAddress.ipv6(), windmillServiceRpcChannelTimeoutSec); case GCP_SERVICE_ADDRESS: return remoteChannel( windmillServiceAddress.gcpServiceAddress(), windmillServiceRpcChannelTimeoutSec); @@ -67,7 +64,8 @@ public static ManagedChannel remoteChannel( windmillServiceRpcChannelTimeoutSec); default: throw new UnsupportedOperationException( - "Only IPV6, GCP_SERVICE_ADDRESS, AUTHENTICATED_GCP_SERVICE_ADDRESS are supported WindmillServiceAddresses."); + "Only GCP_SERVICE_ADDRESS and AUTHENTICATED_GCP_SERVICE_ADDRESS are supported" + + " WindmillServiceAddresses."); } } @@ -105,17 +103,6 @@ public static Channel remoteChannel( } } - public static ManagedChannel remoteChannel( - Inet6Address directEndpoint, int windmillServiceRpcChannelTimeoutSec) { - try { - return createRemoteChannel( - NettyChannelBuilder.forAddress(new InetSocketAddress(directEndpoint, DEFAULT_GRPC_PORT)), - windmillServiceRpcChannelTimeoutSec); - } catch (SSLException sslException) { - throw new WindmillChannelCreationException(directEndpoint.toString(), sslException); - } - } - @SuppressWarnings("nullness") private static ManagedChannel createRemoteChannel( NettyChannelBuilder channelBuilder, int windmillServiceRpcChannelTimeoutSec) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java index 403bb99efb4c..8a1ba2556cf2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java @@ -17,18 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; -import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.DoubleMath.roundToLong; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide; import java.math.RoundingMode; -import java.util.Map; -import java.util.Map.Entry; -import java.util.function.Function; -import java.util.function.Supplier; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,22 +29,11 @@ @Internal final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor { private static final Logger LOG = LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class); - private final Supplier activeWorkBudgetSupplier; - - EvenGetWorkBudgetDistributor(Supplier activeWorkBudgetSupplier) { - this.activeWorkBudgetSupplier = activeWorkBudgetSupplier; - } - - private static boolean isBelowFiftyPercentOfTarget( - GetWorkBudget remaining, GetWorkBudget target) { - return remaining.items() < roundToLong(target.items() * 0.5, RoundingMode.CEILING) - || remaining.bytes() < roundToLong(target.bytes() * 0.5, RoundingMode.CEILING); - } @Override public void distributeBudget( - ImmutableCollection budgetOwners, GetWorkBudget getWorkBudget) { - if (budgetOwners.isEmpty()) { + ImmutableCollection budgetSpenders, GetWorkBudget getWorkBudget) { + if (budgetSpenders.isEmpty()) { LOG.debug("Cannot distribute budget to no owners."); return; } @@ -61,38 +43,15 @@ public void distributeBudget( return; } - Map desiredBudgets = computeDesiredBudgets(budgetOwners, getWorkBudget); - - for (Entry streamAndDesiredBudget : desiredBudgets.entrySet()) { - GetWorkBudgetSpender getWorkBudgetSpender = streamAndDesiredBudget.getKey(); - GetWorkBudget desired = streamAndDesiredBudget.getValue(); - GetWorkBudget remaining = getWorkBudgetSpender.remainingBudget(); - if (isBelowFiftyPercentOfTarget(remaining, desired)) { - GetWorkBudget adjustment = desired.subtract(remaining); - getWorkBudgetSpender.adjustBudget(adjustment); - } - } + GetWorkBudget budgetPerStream = computeDesiredPerStreamBudget(budgetSpenders, getWorkBudget); + budgetSpenders.forEach(getWorkBudgetSpender -> getWorkBudgetSpender.setBudget(budgetPerStream)); } - private ImmutableMap computeDesiredBudgets( + private GetWorkBudget computeDesiredPerStreamBudget( ImmutableCollection streams, GetWorkBudget totalGetWorkBudget) { - GetWorkBudget activeWorkBudget = activeWorkBudgetSupplier.get(); - LOG.info("Current active work budget: {}", activeWorkBudget); - // TODO: Fix possibly non-deterministic handing out of budgets. - // Rounding up here will drift upwards over the lifetime of the streams. 
- GetWorkBudget budgetPerStream = - GetWorkBudget.builder() - .setItems( - divide( - totalGetWorkBudget.items() - activeWorkBudget.items(), - streams.size(), - RoundingMode.CEILING)) - .setBytes( - divide( - totalGetWorkBudget.bytes() - activeWorkBudget.bytes(), - streams.size(), - RoundingMode.CEILING)) - .build(); - return streams.stream().collect(toImmutableMap(Function.identity(), unused -> budgetPerStream)); + return GetWorkBudget.builder() + .setItems(divide(totalGetWorkBudget.items(), streams.size(), RoundingMode.CEILING)) + .setBytes(divide(totalGetWorkBudget.bytes(), streams.size(), RoundingMode.CEILING)) + .build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java index 43c0d46139da..2013c9ff1cb7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java @@ -17,13 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import java.util.function.Supplier; import org.apache.beam.sdk.annotations.Internal; @Internal public final class GetWorkBudgetDistributors { - public static GetWorkBudgetDistributor distributeEvenly( - Supplier activeWorkBudgetSupplier) { - return new EvenGetWorkBudgetDistributor(activeWorkBudgetSupplier); + public static GetWorkBudgetDistributor distributeEvenly() { + return new EvenGetWorkBudgetDistributor(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java index 254b2589062e..decf101a641b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java @@ -22,11 +22,9 @@ * org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget} */ public interface GetWorkBudgetSpender { - void adjustBudget(long itemsDelta, long bytesDelta); + void setBudget(long items, long bytes); - default void adjustBudget(GetWorkBudget adjustment) { - adjustBudget(adjustment.items(), adjustment.bytes()); + default void setBudget(GetWorkBudget budget) { + setBudget(budget.items(), budget.bytes()); } - - GetWorkBudget remainingBudget(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index b3f7467cdbd3..90ffb3d3fbcf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -245,18 +245,10 @@ public void halfClose() { } @Override - public void 
adjustBudget(long itemsDelta, long bytesDelta) { + public void setBudget(GetWorkBudget newBudget) { // no-op. } - @Override - public GetWorkBudget remainingBudget() { - return GetWorkBudget.builder() - .setItems(request.getMaxItems()) - .setBytes(request.getMaxBytes()) - .build(); - } - @Override public boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException { while (done.getCount() > 0) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index ed8815c48e76..0092fcc7bcd1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -30,9 +30,7 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Comparator; import java.util.HashSet; -import java.util.List; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; @@ -46,7 +44,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker; @@ -71,7 +68,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.junit.After; import org.junit.Before; @@ -92,7 +88,6 @@ public class FanOutStreamingEngineWorkerHarnessTest { .setDirectEndpoint(DEFAULT_WINDMILL_SERVICE_ADDRESS.gcpServiceAddress().toString()) .build()); - private static final long CLIENT_ID = 1L; private static final String JOB_ID = "jobId"; private static final String PROJECT_ID = "projectId"; private static final String WORKER_ID = "workerId"; @@ -101,6 +96,7 @@ public class FanOutStreamingEngineWorkerHarnessTest { .setJobId(JOB_ID) .setProjectId(PROJECT_ID) .setWorkerId(WORKER_ID) + .setClientId(1L) .build(); @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @@ -134,7 +130,7 @@ private static GetWorkRequest getWorkRequest(long items, long bytes) { .setJobId(JOB_ID) .setProjectId(PROJECT_ID) .setWorkerId(WORKER_ID) - .setClientId(CLIENT_ID) + .setClientId(JOB_HEADER.getClientId()) .setMaxItems(items) .setMaxBytes(bytes) .build(); @@ -174,7 +170,7 @@ public void cleanUp() { stubFactory.shutdown(); } - private FanOutStreamingEngineWorkerHarness newStreamingEngineClient( + private 
FanOutStreamingEngineWorkerHarness newFanOutStreamingEngineWorkerHarness( GetWorkBudget getWorkBudget, GetWorkBudgetDistributor getWorkBudgetDistributor, WorkItemScheduler workItemScheduler) { @@ -186,7 +182,6 @@ private FanOutStreamingEngineWorkerHarness newStreamingEngineClient( stubFactory, getWorkBudgetDistributor, dispatcherClient, - CLIENT_ID, ignored -> mock(WorkCommitter.class), new ThrottlingGetDataMetricTracker(mock(MemoryMonitor.class))); } @@ -201,7 +196,7 @@ public void testStreamsStartCorrectly() throws InterruptedException { spy(new TestGetWorkBudgetDistributor(numBudgetDistributionsExpected)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(items).setBytes(bytes).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); @@ -219,16 +214,14 @@ public void testStreamsStartCorrectly() throws InterruptedException { getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); - StreamingEngineConnectionState currentConnections = - fanOutStreamingEngineWorkProvider.getCurrentConnections(); + StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); - assertEquals(2, currentConnections.windmillConnections().size()); - assertEquals(2, currentConnections.windmillStreams().size()); + assertEquals(2, currentBackends.windmillStreams().size()); Set workerTokens = - currentConnections.windmillConnections().values().stream() - .map(WindmillConnection::backendWorkerToken) + currentBackends.windmillStreams().keySet().stream() + .map(endpoint -> endpoint.workerToken().orElseThrow(IllegalStateException::new)) .collect(Collectors.toSet()); assertTrue(workerTokens.contains(workerToken)); @@ -252,27 +245,6 @@ public void testStreamsStartCorrectly() throws InterruptedException { verify(streamFactory, times(2)).createCommitWorkStream(any(), any()); } - @Test - public void testScheduledBudgetRefresh() throws InterruptedException { - TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(2)); - fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( - GetWorkBudget.builder().setItems(1L).setBytes(1L).build(), - getWorkBudgetDistributor, - noOpProcessWorkItemFn()); - - getWorkerMetadataReady.await(); - fakeGetWorkerMetadataStub.injectWorkerMetadata( - WorkerMetadataResponse.newBuilder() - .setMetadataVersion(1) - .addWorkEndpoints(metadataResponseEndpoint("workerToken")) - .putAllGlobalDataEndpoints(DEFAULT) - .build()); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - verify(getWorkBudgetDistributor, atLeast(2)).distributeBudget(any(), any()); - } - @Test public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() throws InterruptedException { @@ -280,7 +252,7 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor(metadataCount)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); @@ -309,32 +281,28 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() WorkerMetadataResponse.Endpoint.newBuilder() 
.setBackendWorkerToken(workerToken3) .build()) - .putAllGlobalDataEndpoints(DEFAULT) .build(); getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - StreamingEngineConnectionState currentConnections = - fanOutStreamingEngineWorkProvider.getCurrentConnections(); - assertEquals(1, currentConnections.windmillConnections().size()); - assertEquals(1, currentConnections.windmillStreams().size()); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); + StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); + assertEquals(1, currentBackends.windmillStreams().size()); Set workerTokens = - fanOutStreamingEngineWorkProvider.getCurrentConnections().windmillConnections().values() - .stream() - .map(WindmillConnection::backendWorkerToken) + fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams().keySet().stream() + .map(endpoint -> endpoint.workerToken().orElseThrow(IllegalStateException::new)) .collect(Collectors.toSet()); assertFalse(workerTokens.contains(workerToken)); assertFalse(workerTokens.contains(workerToken2)); + assertTrue(currentBackends.globalDataStreams().isEmpty()); } @Test public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedException { String workerToken = "workerToken1"; String workerToken2 = "workerToken2"; - String workerToken3 = "workerToken3"; WorkerMetadataResponse firstWorkerMetadata = WorkerMetadataResponse.newBuilder() @@ -354,42 +322,24 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce .build()) .putAllGlobalDataEndpoints(DEFAULT) .build(); - WorkerMetadataResponse thirdWorkerMetadata = - WorkerMetadataResponse.newBuilder() - .setMetadataVersion(3) - .addWorkEndpoints( - WorkerMetadataResponse.Endpoint.newBuilder() - .setBackendWorkerToken(workerToken3) - .build()) - .putAllGlobalDataEndpoints(DEFAULT) - .build(); - - List workerMetadataResponses = - Lists.newArrayList(firstWorkerMetadata, secondWorkerMetadata, thirdWorkerMetadata); TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(workerMetadataResponses.size())); + spy(new TestGetWorkBudgetDistributor(1)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); getWorkerMetadataReady.await(); - // Make sure we are injecting the metadata from smallest to largest. 
- workerMetadataResponses.stream() - .sorted(Comparator.comparingLong(WorkerMetadataResponse::getMetadataVersion)) - .forEach(fakeGetWorkerMetadataStub::injectWorkerMetadata); - - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - verify(getWorkBudgetDistributor, atLeast(workerMetadataResponses.size())) - .distributeBudget(any(), any()); - } + fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); + getWorkBudgetDistributor.expectNumDistributions(1); + fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); - private void waitForWorkerMetadataToBeConsumed( - TestGetWorkBudgetDistributor getWorkBudgetDistributor) throws InterruptedException { - getWorkBudgetDistributor.waitForBudgetDistribution(); + verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any()); } private static class GetWorkerMetadataTestStub @@ -434,21 +384,24 @@ private void injectWorkerMetadata(WorkerMetadataResponse response) { } private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor { - private final CountDownLatch getWorkBudgetDistributorTriggered; + private CountDownLatch getWorkBudgetDistributorTriggered; private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) { this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); } - @SuppressWarnings("ReturnValueIgnored") - private void waitForBudgetDistribution() throws InterruptedException { - getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); + private boolean waitForBudgetDistribution() throws InterruptedException { + return getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); + } + + private void expectNumDistributions(int numBudgetDistributionsExpected) { + this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); } @Override public void distributeBudget( ImmutableCollection streams, GetWorkBudget getWorkBudget) { - streams.forEach(stream -> stream.adjustBudget(getWorkBudget.items(), getWorkBudget.bytes())); + streams.forEach(stream -> stream.setBudget(getWorkBudget.items(), getWorkBudget.bytes())); getWorkBudgetDistributorTriggered.countDown(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java index dc6cc5641055..32d1f5738086 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java @@ -193,7 +193,7 @@ public void testCloseAllStreams_doesNotCloseUnstartedStreams() { WindmillStreamSender windmillStreamSender = newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build()); - windmillStreamSender.closeAllStreams(); + windmillStreamSender.close(); verifyNoInteractions(streamFactory); } @@ -230,7 +230,7 @@ public void testCloseAllStreams_closesAllStreams() { mockStreamFactory); windmillStreamSender.startStreams(); - windmillStreamSender.closeAllStreams(); + windmillStreamSender.close(); verify(mockGetWorkStream).shutdown(); 
verify(mockGetDataStream).shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java new file mode 100644 index 000000000000..fd2b30238836 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import 
org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final WorkItemScheduler NO_OP_WORK_ITEM_SCHEDULER = + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}; + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setClientId(1L) + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + private static void assertHeader( + Windmill.StreamingGetWorkRequest getWorkRequest, GetWorkBudget expectedInitialBudget) { + assertTrue(getWorkRequest.hasRequest()); + assertFalse(getWorkRequest.hasRequestExtension()); + assertThat(getWorkRequest.getRequest()) + .isEqualTo( + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(expectedInitialBudget.items()) + .setMaxBytes(expectedInitialBudget.bytes()) + .build()); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + return (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + 
Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + // Header and extension. + assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertHeader(requestObserver.sent().get(0), GetWorkBudget.noBudget()); + assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) + .isEqualTo(extension(newBudget)); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_existingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream = + createGetWorkStream( + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build(); + stream.setBudget(newBudget); + GetWorkBudget diff = newBudget.subtract(initialBudget); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Header and extension. + assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); + assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff)); + } + + @Test + public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); + stream = + createGetWorkStream( + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build()); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. 
+ assertThat(requests).hasSize(expectedRequests); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); + } + + @Test + public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + stream.shutdown(); + stream.setBudget( + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(1); + assertHeader(Iterables.getOnlyElement(requests), GetWorkBudget.noBudget()); + } + + @Test + public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build(); + Set scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + initialBudget, + new ThrottleTimer(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize(); + + assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); + assertThat(Iterables.getLast(requests).getRequestExtension()) + .isEqualTo( + extension( + GetWorkBudget.builder() + .setItems(1) + .setBytes(initialBudget.bytes() - inFlightBytes) + .build())); + } + + @Test + public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + Set scheduledWorkItems = new HashSet<>(); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); + stream = + createGetWorkStream( + testStub, + initialBudget, + new ThrottleTimer(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> + scheduledWorkItems.add(work)); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + 
assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); + } + + @Test + public void testOnResponse_stopsThrottling() { + ThrottleTimer throttleTimer = new ThrottleTimer(); + TestGetWorkRequestObserver requestObserver = + new TestGetWorkRequestObserver(new CountDownLatch(1)); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), throttleTimer, NO_OP_WORK_ITEM_SCHEDULER); + stream.startThrottleTimer(); + assertTrue(throttleTimer.throttled()); + testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance()); + assertFalse(throttleTimer.throttled()); + } + + private static class GetWorkStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestGetWorkRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private GetWorkStreamTestStub(TestGetWorkRequestObserver requestObserver) { + this.requestObserver = requestObserver; + } + + @Override + public StreamObserver getWorkStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + + private void injectResponse(Windmill.StreamingGetWorkResponseChunk responseChunk) { + checkNotNull(responseObserver).onNext(responseChunk); + } + } + + private static class TestGetWorkRequestObserver + implements StreamObserver { + private final List requests = + Collections.synchronizedList(new ArrayList<>()); + private final CountDownLatch waitForRequests; + private @Nullable volatile StreamObserver + responseObserver; + + public TestGetWorkRequestObserver(CountDownLatch waitForRequests) { + this.waitForRequests = waitForRequests; + } + + @Override + public void onNext(Windmill.StreamingGetWorkRequest request) { + requests.add(request); + waitForRequests.countDown(); + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + + List sent() { + return requests; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java index 3cda4559c100..c76d5a584184 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java @@ -17,9 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -40,169 +38,79 @@ public class 
EvenGetWorkBudgetDistributorTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - private static GetWorkBudgetDistributor createBudgetDistributor(GetWorkBudget activeWorkBudget) { - return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget); - } + private static GetWorkBudgetSpender createGetWorkBudgetOwner() { + // Lambdas are final and cannot be spied. + return spy( + new GetWorkBudgetSpender() { - private static GetWorkBudgetDistributor createBudgetDistributor(long activeWorkItemsAndBytes) { - return createBudgetDistributor( - GetWorkBudget.builder() - .setItems(activeWorkItemsAndBytes) - .setBytes(activeWorkItemsAndBytes) - .build()); + @Override + public void setBudget(long items, long bytes) {} + }); } @Test public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() { - createBudgetDistributor(1L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.of(), GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); } @Test public void testDistributeBudget_doesNothingWithNoBudget() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())); - createBudgetDistributor(1L) + GetWorkBudgetSpender getWorkBudgetSpender = createGetWorkBudgetOwner(); + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget(ImmutableList.of(getWorkBudgetSpender), GetWorkBudget.noBudget()); verifyNoInteractions(getWorkBudgetSpender); } @Test - public void testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(10L).setBytes(10L).build())); - createBudgetDistributor(0L) - .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(5L).setBytes(5L).build())); - createBudgetDistributor(10L) + public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { + int totalStreams = 10; + long totalItems = 10L; + long totalBytes = 100L; + List streams = new ArrayList<>(); + for (int i = 0; i < totalStreams; i++) { + streams.add(createGetWorkBudgetOwner()); + } + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(20L).setBytes(20L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - 
eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq( - totalGetWorkBudget.items() - - streamRemainingBudget.items() - - activeWorkItemsAndBytes), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); + ImmutableList.copyOf(streams), + GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build()); - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq( - totalGetWorkBudget.bytes() - - streamRemainingBudget.bytes() - - activeWorkItemsAndBytes)); + streams.forEach( + stream -> + verify(stream, times(1)) + .setBudget(eq(GetWorkBudget.builder().setItems(1L).setBytes(10L).build()))); } @Test - public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { - long totalItemsAndBytes = 10L; + public void testDistributeBudget_distributesFairlyWhenNotEven() { + long totalItems = 10L; + long totalBytes = 19L; List streams = new ArrayList<>(); - for (int i = 0; i < totalItemsAndBytes; i++) { - streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()))); + for (int i = 0; i < 3; i++) { + streams.add(createGetWorkBudgetOwner()); } - createBudgetDistributor(0L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.copyOf(streams), - GetWorkBudget.builder() - .setItems(totalItemsAndBytes) - .setBytes(totalItemsAndBytes) - .build()); + 
GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build()); - long itemsAndBytesPerStream = totalItemsAndBytes / streams.size(); streams.forEach( stream -> verify(stream, times(1)) - .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); + .setBudget(eq(GetWorkBudget.builder().setItems(4L).setBytes(7L).build()))); } @Test - public void testDistributeBudget_distributesFairlyWhenNotEven() { + public void testDistributeBudget_distributesBudgetEvenly() { long totalItemsAndBytes = 10L; List streams = new ArrayList<>(); - for (int i = 0; i < 3; i++) { - streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()))); + for (int i = 0; i < totalItemsAndBytes; i++) { + streams.add(createGetWorkBudgetOwner()); } - createBudgetDistributor(0L) + + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.copyOf(streams), GetWorkBudget.builder() @@ -210,24 +118,10 @@ public void testDistributeBudget_distributesFairlyWhenNotEven() { .setBytes(totalItemsAndBytes) .build()); - long itemsAndBytesPerStream = (long) Math.ceil(totalItemsAndBytes / (streams.size() * 1.0)); + long itemsAndBytesPerStream = totalItemsAndBytes / streams.size(); streams.forEach( stream -> verify(stream, times(1)) - .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); - } - - private GetWorkBudgetSpender createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget getWorkBudget) { - return spy( - new GetWorkBudgetSpender() { - @Override - public void adjustBudget(long itemsDelta, long bytesDelta) {} - - @Override - public GetWorkBudget remainingBudget() { - return getWorkBudget; - } - }); + .setBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); } } From 62b083095966fb11be3121abdebe18ab954339d7 Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:34:16 -0400 Subject: [PATCH 77/82] Remove Python 3.8 Support from Apache Beam (#32643) * Remove Python 3.8 Support from Apache Beam * Remove artifact build/publishing * Address comments --- .../workflows/beam_Publish_Beam_SDK_Snapshots.yml | 1 - .github/workflows/build_wheels.yml | 4 ++-- .test-infra/jenkins/metrics_report/tox.ini | 2 +- .test-infra/mock-apis/pyproject.toml | 2 +- .test-infra/tools/python_installer.sh | 2 +- .../org/apache/beam/gradle/BeamModulePlugin.groovy | 1 - contributor-docs/release-guide.md | 2 +- gradle.properties | 2 +- local-env-setup.sh | 4 ++-- .../cloudbuild/playground_cd_examples.sh | 10 +++++----- .../cloudbuild/playground_ci_examples.sh | 10 +++++----- release/src/main/Dockerfile | 3 +-- .../main/python-release/python_release_automation.sh | 2 +- sdks/python/apache_beam/__init__.py | 8 +------- .../ml/inference/test_resources/vllm.dockerfile | 2 +- sdks/python/expansion-service-container/Dockerfile | 2 +- sdks/python/setup.py | 12 +++--------- 17 files changed, 27 insertions(+), 42 deletions(-) diff --git a/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml b/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml index 7107385c1722..e3791119be90 100644 --- a/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml +++ b/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml @@ -66,7 +66,6 @@ jobs: - "java:container:java11" - "java:container:java17" - "java:container:java21" - - "python:container:py38" - "python:container:py39" - "python:container:py310" - "python:container:py311" diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 
0a15ba9d150c..828a6328c0cd 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -49,7 +49,7 @@ jobs: env: EVENT_NAME: ${{ github.event_name }} # Keep in sync with py_version matrix value below - if changed, change that as well. - PY_VERSIONS_FULL: "cp38-* cp39-* cp310-* cp311-* cp312-*" + PY_VERSIONS_FULL: "cp39-* cp310-* cp311-* cp312-*" outputs: gcp-variables-set: ${{ steps.check_gcp_variables.outputs.gcp-variables-set }} py-versions-full: ${{ steps.set-py-versions.outputs.py-versions-full }} @@ -229,7 +229,7 @@ jobs: {"os": "ubuntu-20.04", "runner": [self-hosted, ubuntu-20.04, main], "python": "${{ needs.check_env_variables.outputs.py-versions-test }}", arch: "aarch64" } ] # Keep in sync (remove asterisks) with PY_VERSIONS_FULL env var above - if changed, change that as well. - py_version: ["cp38-", "cp39-", "cp310-", "cp311-", "cp312-"] + py_version: ["cp39-", "cp310-", "cp311-", "cp312-"] steps: - name: Download python source distribution from artifacts uses: actions/download-artifact@v4.1.8 diff --git a/.test-infra/jenkins/metrics_report/tox.ini b/.test-infra/jenkins/metrics_report/tox.ini index 026db5dc4860..d143a0dcf59c 100644 --- a/.test-infra/jenkins/metrics_report/tox.ini +++ b/.test-infra/jenkins/metrics_report/tox.ini @@ -17,7 +17,7 @@ ; TODO(https://github.com/apache/beam/issues/20209): Don't hardcode Py3.8 version. [tox] skipsdist = True -envlist = py38-test,py38-generate-report +envlist = py39-test,py39-generate-report [testenv] commands_pre = diff --git a/.test-infra/mock-apis/pyproject.toml b/.test-infra/mock-apis/pyproject.toml index 680bf489ba13..c98d9152cfb9 100644 --- a/.test-infra/mock-apis/pyproject.toml +++ b/.test-infra/mock-apis/pyproject.toml @@ -27,7 +27,7 @@ packages = [ ] [tool.poetry.dependencies] -python = "^3.8" +python = "^3.9" google = "^3.0.0" grpcio = "^1.53.0" grpcio-tools = "^1.53.0" diff --git a/.test-infra/tools/python_installer.sh b/.test-infra/tools/python_installer.sh index b1b05e597cb3..04e10555243a 100644 --- a/.test-infra/tools/python_installer.sh +++ b/.test-infra/tools/python_installer.sh @@ -20,7 +20,7 @@ set -euo pipefail # Variable containing the python versions to install -python_versions_arr=("3.8.16" "3.9.16" "3.10.10" "3.11.4") +python_versions_arr=("3.9.16" "3.10.10" "3.11.4", "3.12.6") # Install pyenv dependencies. pyenv_dep(){ diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 576b8defb40b..8a094fd56217 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -3152,7 +3152,6 @@ class BeamModulePlugin implements Plugin { mustRunAfter = [ ":runners:flink:${project.ext.latestFlinkVersion}:job-server:shadowJar", ':runners:spark:3:job-server:shadowJar', - ':sdks:python:container:py38:docker', ':sdks:python:container:py39:docker', ':sdks:python:container:py310:docker', ':sdks:python:container:py311:docker', diff --git a/contributor-docs/release-guide.md b/contributor-docs/release-guide.md index d351049c96cd..51f06adf50e4 100644 --- a/contributor-docs/release-guide.md +++ b/contributor-docs/release-guide.md @@ -507,7 +507,7 @@ with tags: `${RELEASE_VERSION}rc${RC_NUM}` Verify that third party licenses are included in Docker. 
You can do this with a simple script: RC_TAG=${RELEASE_VERSION}rc${RC_NUM} - for pyver in 3.8 3.9 3.10 3.11; do + for pyver in 3.9 3.10 3.11 3.12; do docker run --rm --entrypoint sh \ apache/beam_python${pyver}_sdk:${RC_TAG} \ -c 'ls -al /opt/apache/beam/third_party_licenses/ | wc -l' diff --git a/gradle.properties b/gradle.properties index f6e143690a34..4b3a752f0633 100644 --- a/gradle.properties +++ b/gradle.properties @@ -41,4 +41,4 @@ docker_image_default_repo_prefix=beam_ # supported flink versions flink_versions=1.15,1.16,1.17,1.18,1.19 # supported python versions -python_versions=3.8,3.9,3.10,3.11,3.12 +python_versions=3.9,3.10,3.11,3.12 diff --git a/local-env-setup.sh b/local-env-setup.sh index f13dc88432a6..ba30813b2bcc 100755 --- a/local-env-setup.sh +++ b/local-env-setup.sh @@ -55,7 +55,7 @@ if [ "$kernelname" = "Linux" ]; then exit fi - for ver in 3.8 3.9 3.10 3.11 3.12 3; do + for ver in 3.9 3.10 3.11 3.12 3; do apt install --yes python$ver-venv done @@ -89,7 +89,7 @@ elif [ "$kernelname" = "Darwin" ]; then echo "Installing openjdk@8" brew install openjdk@8 fi - for ver in 3.8 3.9 3.10 3.11 3.12; do + for ver in 3.9 3.10 3.11 3.12; do if brew ls --versions python@$ver > /dev/null; then echo "python@$ver already installed. Skipping" brew info python@$ver diff --git a/playground/infrastructure/cloudbuild/playground_cd_examples.sh b/playground/infrastructure/cloudbuild/playground_cd_examples.sh index d05773656b30..e571bc9fc9d9 100644 --- a/playground/infrastructure/cloudbuild/playground_cd_examples.sh +++ b/playground/infrastructure/cloudbuild/playground_cd_examples.sh @@ -97,15 +97,15 @@ LogOutput "Installing python and dependencies." export DEBIAN_FRONTEND=noninteractive apt install -y apt-transport-https ca-certificates software-properties-common curl unzip apt-utils > /dev/null 2>&1 add-apt-repository -y ppa:deadsnakes/ppa > /dev/null 2>&1 && apt update > /dev/null 2>&1 -apt install -y python3.8 python3.8-distutils python3-pip > /dev/null 2>&1 -apt install -y --reinstall python3.8-distutils > /dev/null 2>&1 +apt install -y python3.9 python3-distutils python3-pip > /dev/null 2>&1 +apt install -y --reinstall python3-distutils > /dev/null 2>&1 apt install -y python3-virtualenv virtualenv play_venv source play_venv/bin/activate pip install --upgrade google-api-python-client > /dev/null 2>&1 -python3.8 -m pip install pip --upgrade > /dev/null 2>&1 -ln -s /usr/bin/python3.8 /usr/bin/python > /dev/null 2>&1 -apt install -y python3.8-venv > /dev/null 2>&1 +python3.9 -m pip install pip --upgrade > /dev/null 2>&1 +ln -s /usr/bin/python3.9 /usr/bin/python > /dev/null 2>&1 +apt install -y python3.9-venv > /dev/null 2>&1 LogOutput "Installing Python packages from beam/playground/infrastructure/requirements.txt" cd $BEAM_ROOT_DIR diff --git a/playground/infrastructure/cloudbuild/playground_ci_examples.sh b/playground/infrastructure/cloudbuild/playground_ci_examples.sh index 437cc337faf7..2a63382615a5 100755 --- a/playground/infrastructure/cloudbuild/playground_ci_examples.sh +++ b/playground/infrastructure/cloudbuild/playground_ci_examples.sh @@ -94,12 +94,12 @@ export DEBIAN_FRONTEND=noninteractive LogOutput "Installing Python environment" apt-get install -y apt-transport-https ca-certificates software-properties-common curl unzip apt-utils > /dev/null add-apt-repository -y ppa:deadsnakes/ppa > /dev/null && apt update > /dev/null -apt install -y python3.8 python3.8-distutils python3-pip > /dev/null -apt install --reinstall python3.8-distutils > /dev/null +apt install -y python3.9 
python3-distutils python3-pip > /dev/null +apt install --reinstall python3-distutils > /dev/null pip install --upgrade google-api-python-client > /dev/null -python3.8 -m pip install pip --upgrade > /dev/null -ln -s /usr/bin/python3.8 /usr/bin/python > /dev/null -apt install python3.8-venv > /dev/null +python3.9 -m pip install pip --upgrade > /dev/null +ln -s /usr/bin/python3.9 /usr/bin/python > /dev/null +apt install python3.9-venv > /dev/null LogOutput "Installing Python packages from beam/playground/infrastructure/requirements.txt" pip install -r $BEAM_ROOT_DIR/playground/infrastructure/requirements.txt diff --git a/release/src/main/Dockerfile b/release/src/main/Dockerfile index 14fe6fdb5a49..6503c5c42ba8 100644 --- a/release/src/main/Dockerfile +++ b/release/src/main/Dockerfile @@ -42,12 +42,11 @@ RUN curl https://pyenv.run | bash && \ echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> /root/.bashrc && \ echo ''eval "$(pyenv init -)"'' >> /root/.bashrc && \ source /root/.bashrc && \ - pyenv install 3.8.9 && \ pyenv install 3.9.4 && \ pyenv install 3.10.7 && \ pyenv install 3.11.3 && \ pyenv install 3.12.3 && \ - pyenv global 3.8.9 3.9.4 3.10.7 3.11.3 3.12.3 + pyenv global 3.9.4 3.10.7 3.11.3 3.12.3 # Install a Go version >= 1.16 so we can bootstrap higher # Go versions diff --git a/release/src/main/python-release/python_release_automation.sh b/release/src/main/python-release/python_release_automation.sh index 2f6986885a96..248bdd9b65ac 100755 --- a/release/src/main/python-release/python_release_automation.sh +++ b/release/src/main/python-release/python_release_automation.sh @@ -19,7 +19,7 @@ source release/src/main/python-release/run_release_candidate_python_quickstart.sh source release/src/main/python-release/run_release_candidate_python_mobile_gaming.sh -for version in 3.8 3.9 3.10 3.11 3.12 +for version in 3.9 3.10 3.11 3.12 do run_release_candidate_python_quickstart "tar" "python${version}" run_release_candidate_python_mobile_gaming "tar" "python${version}" diff --git a/sdks/python/apache_beam/__init__.py b/sdks/python/apache_beam/__init__.py index 6e08083bc0de..af88934b0e71 100644 --- a/sdks/python/apache_beam/__init__.py +++ b/sdks/python/apache_beam/__init__.py @@ -70,17 +70,11 @@ import warnings if sys.version_info.major == 3: - if sys.version_info.minor <= 7 or sys.version_info.minor >= 13: + if sys.version_info.minor <= 8 or sys.version_info.minor >= 13: warnings.warn( 'This version of Apache Beam has not been sufficiently tested on ' 'Python %s.%s. You may encounter bugs or missing features.' % (sys.version_info.major, sys.version_info.minor)) - elif sys.version_info.minor == 8: - warnings.warn( - 'Python 3.8 reaches EOL in October 2024 and support will ' - 'be removed from Apache Beam in version 2.61.0. See ' - 'https://github.com/apache/beam/issues/31192 for more ' - 'information.') pass else: raise RuntimeError( diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile index 5abbffdc5a2a..f27abbfd0051 100644 --- a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile +++ b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile @@ -40,7 +40,7 @@ RUN pip install openai vllm RUN apt install libcairo2-dev pkg-config python3-dev -y RUN pip install pycairo -# Copy the Apache Beam worker dependencies from the Beam Python 3.8 SDK image. +# Copy the Apache Beam worker dependencies from the Beam Python 3.12 SDK image. 
COPY --from=apache/beam_python3.12_sdk:2.58.1 /opt/apache/beam /opt/apache/beam # Set the entrypoint to Apache Beam SDK worker launcher. diff --git a/sdks/python/expansion-service-container/Dockerfile b/sdks/python/expansion-service-container/Dockerfile index 5a5ef0f410bc..4e82165f594c 100644 --- a/sdks/python/expansion-service-container/Dockerfile +++ b/sdks/python/expansion-service-container/Dockerfile @@ -17,7 +17,7 @@ ############################################################################### # We just need to support one Python version supported by Beam. -# Picking the current default Beam Python version which is Python 3.8. +# Picking the current default Beam Python version which is Python 3.9. FROM python:3.9-bookworm as expansion-service LABEL Author "Apache Beam " ARG TARGETOS diff --git a/sdks/python/setup.py b/sdks/python/setup.py index cac27db69803..9ae5d3153f51 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -155,7 +155,7 @@ def cythonize(*args, **kwargs): # Exclude 1.5.0 and 1.5.1 because of # https://github.com/pandas-dev/pandas/issues/45725 dataframe_dependency = [ - 'pandas>=1.4.3,!=1.5.0,!=1.5.1,<2.3;python_version>="3.8"', + 'pandas>=1.4.3,!=1.5.0,!=1.5.1,<2.3', ] @@ -271,18 +271,13 @@ def get_portability_package_data(): return files -python_requires = '>=3.8' +python_requires = '>=3.9' -if sys.version_info.major == 3 and sys.version_info.minor >= 12: +if sys.version_info.major == 3 and sys.version_info.minor >= 13: warnings.warn( 'This version of Apache Beam has not been sufficiently tested on ' 'Python %s.%s. You may encounter bugs or missing features.' % (sys.version_info.major, sys.version_info.minor)) -elif sys.version_info.major == 3 and sys.version_info.minor == 8: - warnings.warn('Python 3.8 reaches EOL in October 2024 and support will ' - 'be removed from Apache Beam in version 2.61.0. 
See ' - 'https://github.com/apache/beam/issues/31192 for more ' - 'information.') if __name__ == '__main__': # In order to find the tree of proto packages, the directory @@ -534,7 +529,6 @@ def get_portability_package_data(): 'Intended Audience :: End Users/Desktop', 'License :: OSI Approved :: Apache Software License', 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', From 1a936c5d3ba2cae60fcb6874f49b5115d4838ead Mon Sep 17 00:00:00 2001 From: claudevdm <33973061+claudevdm@users.noreply.github.com> Date: Mon, 21 Oct 2024 15:06:21 -0400 Subject: [PATCH 78/82] Move biquery enrichment transform notebook to examples/notebooks/beam-ml (#32888) Co-authored-by: Claude --- .../notebooks/beam-ml/bigquery_enrichment_transform.ipynb | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename bigquery_enrichment_transform.ipynb => examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb (100%) diff --git a/bigquery_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb similarity index 100% rename from bigquery_enrichment_transform.ipynb rename to examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb From 56696ecd21ddd3b65c6e8bc70c0023bdf991dbe3 Mon Sep 17 00:00:00 2001 From: Damon Date: Mon, 21 Oct 2024 13:06:18 -0700 Subject: [PATCH 79/82] Enable BuildKit on gradle docker task (#32875) * Enable BuildKit on gradle docker task * Revert setting dockerTag --- .../main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy | 1 + 1 file changed, 1 insertion(+) diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy index cd46c1270f83..388069a03983 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy @@ -130,6 +130,7 @@ class BeamDockerPlugin implements Plugin { group = 'Docker' description = 'Builds Docker image.' dependsOn prepare + environment 'DOCKER_BUILDKIT', '1' }) Task tag = project.tasks.create('dockerTag', { From 3767eda41a00d3db5044e7b339fe17d64e5585ca Mon Sep 17 00:00:00 2001 From: Robert Burke Date: Mon, 21 Oct 2024 15:50:39 -0700 Subject: [PATCH 80/82] [prism] Dev prism builds for python and Python Direct Runner fallbacks. (#32876) --- .../runners/direct/direct_runner.py | 56 +++++++++ .../runners/portability/prism_runner.py | 112 +++++++++++++----- 2 files changed, 140 insertions(+), 28 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 49b6622816ce..8b8937653688 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -110,6 +110,38 @@ def visit_transform(self, applied_ptransform): if timer.time_domain == TimeDomain.REAL_TIME: self.supported_by_fnapi_runner = False + class _PrismRunnerSupportVisitor(PipelineVisitor): + """Visitor determining if a Pipeline can be run on the PrismRunner.""" + def accept(self, pipeline): + self.supported_by_prism_runner = True + pipeline.visit(self) + return self.supported_by_prism_runner + + def visit_transform(self, applied_ptransform): + transform = applied_ptransform.transform + # Python SDK assumes the direct runner TestStream implementation is + # being used. 
+ if isinstance(transform, TestStream): + self.supported_by_prism_runner = False + if isinstance(transform, beam.ParDo): + dofn = transform.dofn + # It's uncertain if the Prism Runner supports execution of CombineFns + # with deferred side inputs. + if isinstance(dofn, CombineValuesDoFn): + args, kwargs = transform.raw_side_inputs + args_to_check = itertools.chain(args, kwargs.values()) + if any(isinstance(arg, ArgumentPlaceholder) + for arg in args_to_check): + self.supported_by_prism_runner = False + if userstate.is_stateful_dofn(dofn): + # https://github.com/apache/beam/issues/32786 - + # Remove once Real time clock is used. + _, timer_specs = userstate.get_dofn_specs(dofn) + for timer in timer_specs: + if timer.time_domain == TimeDomain.REAL_TIME: + self.supported_by_prism_runner = False + + tryingPrism = False # Check whether all transforms used in the pipeline are supported by the # FnApiRunner, and the pipeline was not meant to be run as streaming. if _FnApiRunnerSupportVisitor().accept(pipeline): @@ -122,9 +154,33 @@ def visit_transform(self, applied_ptransform): beam_provision_api_pb2.ProvisionInfo( pipeline_options=encoded_options)) runner = fn_runner.FnApiRunner(provision_info=provision_info) + elif _PrismRunnerSupportVisitor().accept(pipeline): + _LOGGER.info('Running pipeline with PrismRunner.') + from apache_beam.runners.portability import prism_runner + runner = prism_runner.PrismRunner() + tryingPrism = True else: runner = BundleBasedDirectRunner() + if tryingPrism: + try: + pr = runner.run_pipeline(pipeline, options) + # This is non-blocking, so if the state is *already* finished, something + # probably failed on job submission. + if pr.state.is_terminal() and pr.state != PipelineState.DONE: + _LOGGER.info( + 'Pipeline failed on PrismRunner, falling back toDirectRunner.') + runner = BundleBasedDirectRunner() + else: + return pr + except Exception as e: + # If prism fails in Preparing the portable job, then the PortableRunner + # code raises an exception. Catch it, log it, and use the Direct runner + # instead. + _LOGGER.info('Exception with PrismRunner:\n %s\n' % (e)) + _LOGGER.info('Falling back to DirectRunner') + runner = BundleBasedDirectRunner() + return runner.run_pipeline(pipeline, options) diff --git a/sdks/python/apache_beam/runners/portability/prism_runner.py b/sdks/python/apache_beam/runners/portability/prism_runner.py index eeccaf5748ce..77dc8a214e8e 100644 --- a/sdks/python/apache_beam/runners/portability/prism_runner.py +++ b/sdks/python/apache_beam/runners/portability/prism_runner.py @@ -27,6 +27,7 @@ import platform import shutil import stat +import subprocess import typing import urllib import zipfile @@ -167,38 +168,93 @@ def construct_download_url(self, root_tag: str, sys: str, mach: str) -> str: def path_to_binary(self) -> str: if self._path is not None: - if not os.path.exists(self._path): - url = urllib.parse.urlparse(self._path) - if not url.scheme: - raise ValueError( - 'Unable to parse binary URL "%s". If using a full URL, make ' - 'sure the scheme is specified. If using a local file xpath, ' - 'make sure the file exists; you may have to first build prism ' - 'using `go build `.' % (self._path)) - - # We have a URL, see if we need to construct a valid file name. - if self._path.startswith(GITHUB_DOWNLOAD_PREFIX): - # If this URL starts with the download prefix, let it through. - return self._path - # The only other valid option is a github release page. 
- if not self._path.startswith(GITHUB_TAG_PREFIX): - raise ValueError( - 'Provided --prism_location URL is not an Apache Beam Github ' - 'Release page URL or download URL: %s' % (self._path)) - # Get the root tag for this URL - root_tag = os.path.basename(os.path.normpath(self._path)) - return self.construct_download_url( - root_tag, platform.system(), platform.machine()) - return self._path - else: - if '.dev' in self._version: + # The path is overidden, check various cases. + if os.path.exists(self._path): + # The path is local and exists, use directly. + return self._path + + # Check if the path is a URL. + url = urllib.parse.urlparse(self._path) + if not url.scheme: + raise ValueError( + 'Unable to parse binary URL "%s". If using a full URL, make ' + 'sure the scheme is specified. If using a local file xpath, ' + 'make sure the file exists; you may have to first build prism ' + 'using `go build `.' % (self._path)) + + # We have a URL, see if we need to construct a valid file name. + if self._path.startswith(GITHUB_DOWNLOAD_PREFIX): + # If this URL starts with the download prefix, let it through. + return self._path + # The only other valid option is a github release page. + if not self._path.startswith(GITHUB_TAG_PREFIX): raise ValueError( - 'Unable to derive URL for dev versions "%s". Please provide an ' - 'alternate version to derive the release URL with the ' - '--prism_beam_version_override flag.' % (self._version)) + 'Provided --prism_location URL is not an Apache Beam Github ' + 'Release page URL or download URL: %s' % (self._path)) + # Get the root tag for this URL + root_tag = os.path.basename(os.path.normpath(self._path)) + return self.construct_download_url( + root_tag, platform.system(), platform.machine()) + + if '.dev' not in self._version: + # Not a development version, so construct the production download URL return self.construct_download_url( self._version, platform.system(), platform.machine()) + # This is a development version! Assume Go is installed. + # Set the install directory to the cache location. + envdict = {**os.environ, "GOBIN": self.BIN_CACHE} + PRISMPKG = "github.com/apache/beam/sdks/v2/go/cmd/prism" + + process = subprocess.run(["go", "install", PRISMPKG], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=envdict, + check=False) + if process.returncode == 0: + # Successfully installed + return '%s/prism' % (self.BIN_CACHE) + + # We failed to build for some reason. + output = process.stdout.decode("utf-8") + if ("not in a module" not in output) and ( + "no required module provides" not in output): + # This branch handles two classes of failures: + # 1. Go isn't installed, so it needs to be installed by the Beam SDK + # developer. + # 2. Go is installed, and they are building in a local version of Prism, + # but there was a compile error that the developer should address. + # Either way, the @latest fallback either would fail, or hide the error, + # so fail now. + _LOGGER.info(output) + raise ValueError( + 'Unable to install a local of Prism: "%s";\n' + 'Likely Go is not installed, or a local change to Prism did not ' + 'compile.\nPlease install Go (see https://go.dev/doc/install) to ' + 'enable automatic local builds.\n' + 'Alternatively provide a binary with the --prism_location flag.' + '\nCaptured output:\n %s' % (self._version, output)) + + # Go is installed and claims we're not in a Go module that has access to + # the Prism package. + + # Fallback to using the @latest version of prism, which works everywhere. 
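+    # Roughly equivalent to running
+    #   GOBIN=<BIN_CACHE> go install github.com/apache/beam/sdks/v2/go/cmd/prism@latest
+    # from a shell, where <BIN_CACHE> stands for self.BIN_CACHE; on success the
+    # binary is then served from <BIN_CACHE>/prism.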
+ process = subprocess.run(["go", "install", PRISMPKG + "@latest"], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=envdict, + check=False) + + if process.returncode == 0: + return '%s/prism' % (self.BIN_CACHE) + + output = process.stdout.decode("utf-8") + raise ValueError( + 'We were unable to execute the subprocess "%s" to automatically ' + 'build prism.\nAlternatively provide an alternate binary with the ' + '--prism_location flag.' + '\nCaptured output:\n %s' % (process.args, output)) + def subprocess_cmd_and_endpoint( self) -> typing.Tuple[typing.List[typing.Any], str]: bin_path = self.local_bin( From 8d22fc2f72e6d0781eea465e773c542b5907686d Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 22 Oct 2024 00:57:48 -0700 Subject: [PATCH 81/82] Remove experiments guarding isolated channels enablement based on jobsettings (#32782) --- .../worker/StreamingDataflowWorker.java | 7 +------ .../client/grpc/GrpcDispatcherClient.java | 19 +++++-------------- .../client/grpc/GrpcDispatcherClientTest.java | 15 +-------------- 3 files changed, 7 insertions(+), 34 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index ecdba404151e..524906023722 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -140,8 +140,6 @@ public final class StreamingDataflowWorker { private static final int DEFAULT_STATUS_PORT = 8081; private static final Random CLIENT_ID_GENERATOR = new Random(); private static final String CHANNELZ_PATH = "/channelz"; - public static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL = - "streaming_engine_use_job_settings_for_heartbeat_pool"; private final WindmillStateCache stateCache; private final StreamingWorkerStatusPages statusPages; @@ -249,10 +247,7 @@ private StreamingDataflowWorker( GET_DATA_STREAM_TIMEOUT, windmillServer::getDataStream); getDataClient = new StreamPoolGetDataClient(getDataMetricTracker, getDataStreamPool); - // Experiment gates the logic till backend changes are rollback safe - if (!DataflowRunner.hasExperiment( - options, STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL) - || options.getUseSeparateWindmillHeartbeatStreams() != null) { + if (options.getUseSeparateWindmillHeartbeatStreams() != null) { heartbeatSender = StreamPoolHeartbeatSender.Create( Boolean.TRUE.equals(options.getUseSeparateWindmillHeartbeatStreams()) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java index f96464150d4a..6bae84483d16 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java @@ -30,7 +30,6 @@ import java.util.concurrent.atomic.AtomicReference; import javax.annotation.concurrent.GuardedBy; import 
javax.annotation.concurrent.ThreadSafe; -import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; @@ -53,8 +52,6 @@ public class GrpcDispatcherClient { private static final Logger LOG = LoggerFactory.getLogger(GrpcDispatcherClient.class); - static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS = - "streaming_engine_use_job_settings_for_isolated_channels"; private final CountDownLatch onInitializedEndpoints; /** @@ -80,18 +77,12 @@ private GrpcDispatcherClient( DispatcherStubs initialDispatcherStubs, Random rand) { this.windmillStubFactoryFactory = windmillStubFactoryFactory; - if (DataflowRunner.hasExperiment( - options, STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)) { - if (options.getUseWindmillIsolatedChannels() != null) { - this.useIsolatedChannels.set(options.getUseWindmillIsolatedChannels()); - this.reactToIsolatedChannelsJobSetting = false; - } else { - this.useIsolatedChannels.set(false); - this.reactToIsolatedChannelsJobSetting = true; - } - } else { - this.useIsolatedChannels.set(Boolean.TRUE.equals(options.getUseWindmillIsolatedChannels())); + if (options.getUseWindmillIsolatedChannels() != null) { + this.useIsolatedChannels.set(options.getUseWindmillIsolatedChannels()); this.reactToIsolatedChannelsJobSetting = false; + } else { + this.useIsolatedChannels.set(false); + this.reactToIsolatedChannelsJobSetting = true; } this.windmillStubFactory.set( windmillStubFactoryFactory.makeWindmillStubFactory(useIsolatedChannels.get())); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java index 3f746d91a868..c04456906ea2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java @@ -34,7 +34,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactoryImpl; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.hamcrest.Matcher; import org.junit.Test; @@ -55,9 +54,6 @@ public static class RespectsJobSettingTest { public void createsNewStubWhenIsolatedChannelsConfigIsChanged() { DataflowWorkerHarnessOptions options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - options.setExperiments( - Lists.newArrayList( - GrpcDispatcherClient.STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)); GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); // Create first time with Isolated channels disabled @@ -91,27 +87,18 @@ public static class RespectsPipelineOptionsTest { public static Collection data() { List list = new 
ArrayList<>(); for (Boolean pipelineOption : new Boolean[] {true, false}) { - list.add(new Object[] {/*experimentEnabled=*/ false, pipelineOption}); - list.add(new Object[] {/*experimentEnabled=*/ true, pipelineOption}); + list.add(new Object[] {pipelineOption}); } return list; } @Parameter(0) - public Boolean experimentEnabled; - - @Parameter(1) public Boolean pipelineOption; @Test public void ignoresIsolatedChannelsConfigWithPipelineOption() { DataflowWorkerHarnessOptions options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - if (experimentEnabled) { - options.setExperiments( - Lists.newArrayList( - GrpcDispatcherClient.STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)); - } options.setUseWindmillIsolatedChannels(pipelineOption); GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); From f28ca3ca10647ae5d98cfe65a196929c6972711d Mon Sep 17 00:00:00 2001 From: Damon Date: Tue, 22 Oct 2024 09:58:16 -0700 Subject: [PATCH 82/82] Add target parameter to BeamDockerPlugin (#32890) --- .../groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy index 388069a03983..b3949223f074 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy @@ -59,6 +59,7 @@ class BeamDockerPlugin implements Plugin { boolean load = false boolean push = false String builder = null + String target = null File resolvedDockerfile = null File resolvedDockerComposeTemplate = null @@ -289,6 +290,9 @@ class BeamDockerPlugin implements Plugin { } else { buildCommandLine.addAll(['-t', "${-> ext.name}", '.']) } + if (ext.target != null && ext.target != "") { + buildCommandLine.addAll '--target', ext.target + } logger.debug("${buildCommandLine}" as String) return buildCommandLine }
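+  // Example (hypothetical stage name): setting the extension's target to
+  // 'final' makes the generated command carry "--target final", so Docker
+  // stops at that stage of a multi-stage Dockerfile; leaving target unset
+  // builds the full image exactly as before.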