diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index ae98449a1cdf..0a7c3a758911 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -78,6 +78,13 @@ def config_schema(self, type): def description(self, type): return None + def __repr__(self): + ts = ','.join(sorted(self.provided_transforms())) + return f'{self.__class__.__name__}[{ts}]' + + def to_json(self): + return {'type': self.__class__.__name__} + def requires_inputs(self, typ: str, args: Mapping[str, Any]) -> bool: """Returns whether this transform requires inputs. @@ -105,6 +112,16 @@ def underlying_provider(self): """ return self + def with_underlying_provider(self, provider): + """Returns a provider that vends as many transforms of this provider as + possible but uses the given underlying provider. + + This is used to try to minimize the number of distinct environments that + are used as many providers implicitly or explicitly link in many common + transforms. + """ + return None + def affinity(self, other: "Provider"): """Returns a value approximating how good it would be for this provider to be used immediately following a transform from the other provider @@ -152,6 +169,17 @@ def __init__(self, urns, service): def provided_transforms(self): return self._urns.keys() + def with_underlying_provider(self, other_provider): + if not isinstance(other_provider, ExternalProvider): + return + + all_urns = set(other_provider.schema_transforms().keys()).union( + set(other_provider._urns.values())) + for name, urn in self._urns.items(): + if urn in all_urns and name not in other_provider._urns: + other_provider._urns[name] = urn + return other_provider + def schema_transforms(self): if callable(self._service): self._service = self._service() @@ -518,6 +546,15 @@ def cache_artifacts(self): def underlying_provider(self): return self.sql_provider() + def with_underlying_provider(self, other_provider): + new_sql_provider = self.sql_provider().with_underlying_provider( + other_provider) + if new_sql_provider is None: + new_sql_provider = other_provider + if new_sql_provider and 'Sql' in set( + new_sql_provider.provided_transforms()): + return SqlBackedProvider(self._transforms, new_sql_provider) + def to_json(self): return {'type': "SqlBackedProvider"} @@ -995,7 +1032,12 @@ def __exit__(self, *args): @ExternalProvider.register_provider_type('renaming') class RenamingProvider(Provider): - def __init__(self, transforms, mappings, underlying_provider, defaults=None): + def __init__( + self, + transforms: Mapping[str, str], + mappings: Mapping[str, Mapping[str, str]], + underlying_provider: Provider, + defaults: Optional[Mapping[str, Mapping[str, Any]]] = None): if isinstance(underlying_provider, dict): underlying_provider = ExternalProvider.provider_from_spec( underlying_provider) @@ -1097,6 +1139,23 @@ def _affinity(self, other): def underlying_provider(self): return self._underlying_provider.underlying_provider() + def with_underlying_provider(self, other_provider): + new_underlying_provider = ( + self._underlying_provider.with_underlying_provider(other_provider)) + if new_underlying_provider: + provided = set(new_underlying_provider.provided_transforms()) + new_transforms = { + alias: actual + for alias, + actual in self._transforms.items() if actual in provided + } + if new_transforms: + return RenamingProvider( + new_transforms, + self._mappings, + new_underlying_provider, + self._defaults) + def cache_artifacts(self): self._underlying_provider.cache_artifacts() diff --git a/sdks/python/apache_beam/yaml/yaml_provider_test.py b/sdks/python/apache_beam/yaml/yaml_provider_test.py new file mode 100644 index 000000000000..30483eb96177 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_provider_test.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import unittest + +import apache_beam as beam +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.yaml import yaml_provider +from apache_beam.yaml import yaml_transform + + +class FakeTransform(beam.PTransform): + def __init__(self, creator, urn): + self.creator = creator + self.urn = urn + + def expand(self, pcoll): + return pcoll | beam.Map(lambda x: x + ((self.urn, self.creator), )) + + +class FakeExternalProvider(yaml_provider.ExternalProvider): + def __init__( + self, id, known_transforms, extra_transform_urns, error_on_use=False): + super().__init__(known_transforms, None) + self._id = id + self._schema_transforms = {urn: None for urn in extra_transform_urns} + self._error_on_use = error_on_use + + def create_transform(self, type, *unused_args, **unused_kwargs): + if self._error_on_use: + raise RuntimeError(f'Provider {self._id} should not be used.') + return FakeTransform(self._id, self._urns[type]) + + def available(self): + # Claim we're available even if we error on use. + return True + + +class YamlProvidersTest(unittest.TestCase): + def test_external_with_underlying_provider(self): + providerA = FakeExternalProvider("A", {'A': 'a:urn'}, ['b:urn']) + providerB = FakeExternalProvider("B", {'B': 'b:urn', 'alias': 'a:urn'}, []) + newA = providerB.with_underlying_provider(providerA) + + self.assertIn('B', list(newA.provided_transforms())) + t = newA.create_transform('B') + self.assertEqual('A', t.creator) + self.assertEqual('b:urn', t.urn) + + self.assertIn('alias', list(newA.provided_transforms())) + t = newA.create_transform('alias') + self.assertEqual('A', t.creator) + self.assertEqual('a:urn', t.urn) + + def test_renaming_with_underlying_provider(self): + providerA = FakeExternalProvider("A", {'A': 'a:urn'}, ['b:urn']) + providerB = FakeExternalProvider("B", {'B': 'b:urn', 'C': 'c:urn'}, []) + providerR = yaml_provider.RenamingProvider( # keep wrapping + {'RenamedB': 'B', 'RenamedC': 'C' }, + {'RenamedB': {}, 'RenamedC': {}}, + providerB) + + newR = providerR.with_underlying_provider(providerA) + self.assertIn('RenamedB', list(newR.provided_transforms())) + self.assertNotIn('RenamedC', list(newR.provided_transforms())) + t = newR.create_transform('RenamedB', {}, None) + self.assertEqual('A', t.creator) + self.assertEqual('b:urn', t.urn) + + def test_extended_providers_reused(self): + providerA = FakeExternalProvider("A", {'A': 'a:urn'}, ['b:urn']) + providerB = FakeExternalProvider( + "B", { + 'B': 'b:urn', 'C': 'c:urn' + }, [], error_on_use=True) + providerR = yaml_provider.RenamingProvider( # keep wrapping + {'RenamedB': 'B', 'RenamedC': 'C' }, + {'RenamedB': {}, 'RenamedC': {}}, + providerB) + + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + result = p | 'Yaml1' >> yaml_transform.YamlTransform( + ''' + type: chain + transforms: + - type: Create + config: + elements: [0] + - type: A + - type: B + - type: RenamedB + ''', + providers=[providerA, providerB, providerR]) + # All of these transforms should be serviced by providerA, + # negating the need to invoke providerB. + assert_that( + result, + equal_to([(0, ('a:urn', 'A'), ('b:urn', 'A'), ('b:urn', 'A'))])) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 03574b5f98ff..dffb8108e362 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -28,6 +28,7 @@ from typing import List from typing import Mapping from typing import Set +from typing import Tuple import yaml from yaml.loader import SafeLoader @@ -178,6 +179,46 @@ def get_line(obj): return getattr(obj, '_line_', 'unknown') +class ProviderSet: + def __init__(self, providers: Mapping[str, Iterable[yaml_provider.Provider]]): + self._providers = providers + self._original_providers = set(p for ps in providers.values() for p in ps) + self._expanded: Set[yaml_provider.Provider] = set() + + def __contains__(self, key: str): + return key in self._providers + + def __getitem__(self, key: str) -> Iterable[yaml_provider.Provider]: + return self._providers[key] + + def items(self) -> Iterable[Tuple[str, Iterable[yaml_provider.Provider]]]: + return self._providers.items() + + def use(self, provider): + underlying_provider = provider.underlying_provider() + if provider not in self._original_providers: + return + if underlying_provider in self._expanded: + return + self._expanded.add(underlying_provider) + + # For every original provider, see if it can also be implemented in terms + # of this provider. + # This is useful because providers often vend everything that's been linked + # in, and we'd like to minimize provider boundary crossings. + # We do this lazily rather than when constructing the provider set to avoid + # instantiating every possible provider. + for other_provider in self._original_providers: + new_provider = other_provider.with_underlying_provider( + underlying_provider) + if new_provider is not None: + for t in new_provider.provided_transforms(): + existing_providers = set( + p.underlying_provider() for p in self._providers[t]) + if new_provider.underlying_provider() not in existing_providers: + self._providers[t].append(new_provider) + + class LightweightScope(object): def __init__(self, transforms): self._transforms = transforms @@ -220,7 +261,7 @@ def __init__( root, inputs: Mapping[str, Any], transforms: Iterable[dict], - providers: Mapping[str, Iterable[yaml_provider.Provider]], + providers: ProviderSet, input_providers: Iterable[yaml_provider.Provider]): super().__init__(transforms) self.root = root @@ -371,6 +412,7 @@ def create_ptransform(self, spec, input_pcolls): if pcoll in providers_by_input ] provider = self.best_provider(spec, input_providers) + self.providers.use(provider) config = SafeLineLoader.strip_metadata(spec.get('config', {})) if not isinstance(config, dict): @@ -509,9 +551,7 @@ def expand_composite_transform(spec, scope): for (key, value) in empty_if_explicitly_empty(spec['input']).items() }, spec['transforms'], - yaml_provider.merge_providers( - yaml_provider.parse_providers(spec.get('providers', [])), - scope.providers), + scope.providers, scope.input_providers) class CompositePTransform(beam.PTransform): @@ -1005,7 +1045,7 @@ def expand(self, pcolls): root, pcolls, transforms=[self._spec], - providers=self._providers, + providers=ProviderSet(self._providers), input_providers={ pcoll: python_provider for pcoll in pcolls.values() diff --git a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py index f00403b07e2a..82a8d38df3cb 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py @@ -36,7 +36,7 @@ def get_scope_by_spec(self, p, spec): scope = Scope( beam.pvalue.PBegin(p), {}, spec['transforms'], - yaml_provider.standard_providers(), {}) + yaml_transform.ProviderSet(yaml_provider.standard_providers()), {}) return scope, spec def test_get_pcollection_input(self): diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index 39f31619a741..c30472a399e7 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -23,6 +23,7 @@ from apache_beam.yaml import YamlTransform from apache_beam.yaml import yaml_provider from apache_beam.yaml.yaml_provider import InlineProvider +from apache_beam.yaml.yaml_transform import ProviderSet from apache_beam.yaml.yaml_transform import SafeLineLoader from apache_beam.yaml.yaml_transform import Scope from apache_beam.yaml.yaml_transform import chain_as_composite @@ -119,7 +120,7 @@ def get_scope_by_spec(self, p, spec, inputs=None): beam.pvalue.PBegin(p), inputs, spec['transforms'], - yaml_provider.standard_providers(), {}) + ProviderSet(yaml_provider.standard_providers()), {}) return scope, spec def test_pipeline_as_composite_with_type_transforms(self):