Skip to content

Commit

Permalink
[YAML] Increase re-use of providers with implicitly overlapping trans…
Browse files Browse the repository at this point in the history
…forms. (apache#30793)

We use the schema transform discovery service to determine all the
transforms that a given provider vends which may be larger than those
that were explicitly declared (e.g. the basic mapping transforms
are part of java core) and use this to expand the possible set of
transforms this provider can be used to instanciate (attaching the
appropreate renaming, etc. as needed). This can greatly increase the
chances that we can use the same provider, and hence same SDK
and environment, for adjacent steps.

This is done lazily as transforms are used to avoid instanciating
all possible providers for all pipelines.
  • Loading branch information
robertwb authored May 1, 2024
1 parent 44e2abe commit dc5841d
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 8 deletions.
61 changes: 60 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ def config_schema(self, type):
def description(self, type):
return None

def __repr__(self):
ts = ','.join(sorted(self.provided_transforms()))
return f'{self.__class__.__name__}[{ts}]'

def to_json(self):
return {'type': self.__class__.__name__}

def requires_inputs(self, typ: str, args: Mapping[str, Any]) -> bool:
"""Returns whether this transform requires inputs.
Expand Down Expand Up @@ -107,6 +114,16 @@ def underlying_provider(self):
"""
return self

def with_underlying_provider(self, provider):
"""Returns a provider that vends as many transforms of this provider as
possible but uses the given underlying provider.
This is used to try to minimize the number of distinct environments that
are used as many providers implicitly or explicitly link in many common
transforms.
"""
return None

def affinity(self, other: "Provider"):
"""Returns a value approximating how good it would be for this provider
to be used immediately following a transform from the other provider
Expand Down Expand Up @@ -154,6 +171,17 @@ def __init__(self, urns, service):
def provided_transforms(self):
return self._urns.keys()

def with_underlying_provider(self, other_provider):
if not isinstance(other_provider, ExternalProvider):
return

all_urns = set(other_provider.schema_transforms().keys()).union(
set(other_provider._urns.values()))
for name, urn in self._urns.items():
if urn in all_urns and name not in other_provider._urns:
other_provider._urns[name] = urn
return other_provider

def schema_transforms(self):
if callable(self._service):
self._service = self._service()
Expand Down Expand Up @@ -520,6 +548,15 @@ def cache_artifacts(self):
def underlying_provider(self):
return self.sql_provider()

def with_underlying_provider(self, other_provider):
new_sql_provider = self.sql_provider().with_underlying_provider(
other_provider)
if new_sql_provider is None:
new_sql_provider = other_provider
if new_sql_provider and 'Sql' in set(
new_sql_provider.provided_transforms()):
return SqlBackedProvider(self._transforms, new_sql_provider)

def to_json(self):
return {'type': "SqlBackedProvider"}

Expand Down Expand Up @@ -1007,7 +1044,12 @@ def __exit__(self, *args):

@ExternalProvider.register_provider_type('renaming')
class RenamingProvider(Provider):
def __init__(self, transforms, mappings, underlying_provider, defaults=None):
def __init__(
self,
transforms: Mapping[str, str],
mappings: Mapping[str, Mapping[str, str]],
underlying_provider: Provider,
defaults: Optional[Mapping[str, Mapping[str, Any]]] = None):
if isinstance(underlying_provider, dict):
underlying_provider = ExternalProvider.provider_from_spec(
underlying_provider)
Expand Down Expand Up @@ -1109,6 +1151,23 @@ def _affinity(self, other):
def underlying_provider(self):
return self._underlying_provider.underlying_provider()

def with_underlying_provider(self, other_provider):
new_underlying_provider = (
self._underlying_provider.with_underlying_provider(other_provider))
if new_underlying_provider:
provided = set(new_underlying_provider.provided_transforms())
new_transforms = {
alias: actual
for alias,
actual in self._transforms.items() if actual in provided
}
if new_transforms:
return RenamingProvider(
new_transforms,
self._mappings,
new_underlying_provider,
self._defaults)

def cache_artifacts(self):
self._underlying_provider.cache_artifacts()

Expand Down
120 changes: 120 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_provider_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import unittest

import apache_beam as beam
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform


class FakeTransform(beam.PTransform):
def __init__(self, creator, urn):
self.creator = creator
self.urn = urn

def expand(self, pcoll):
return pcoll | beam.Map(lambda x: x + ((self.urn, self.creator), ))


class FakeExternalProvider(yaml_provider.ExternalProvider):
def __init__(
self, id, known_transforms, extra_transform_urns, error_on_use=False):
super().__init__(known_transforms, None)
self._id = id
self._schema_transforms = {urn: None for urn in extra_transform_urns}
self._error_on_use = error_on_use

def create_transform(self, type, *unused_args, **unused_kwargs):
if self._error_on_use:
raise RuntimeError(f'Provider {self._id} should not be used.')
return FakeTransform(self._id, self._urns[type])

def available(self):
# Claim we're available even if we error on use.
return True


class YamlProvidersTest(unittest.TestCase):
def test_external_with_underlying_provider(self):
providerA = FakeExternalProvider("A", {'A': 'a:urn'}, ['b:urn'])
providerB = FakeExternalProvider("B", {'B': 'b:urn', 'alias': 'a:urn'}, [])
newA = providerB.with_underlying_provider(providerA)

self.assertIn('B', list(newA.provided_transforms()))
t = newA.create_transform('B')
self.assertEqual('A', t.creator)
self.assertEqual('b:urn', t.urn)

self.assertIn('alias', list(newA.provided_transforms()))
t = newA.create_transform('alias')
self.assertEqual('A', t.creator)
self.assertEqual('a:urn', t.urn)

def test_renaming_with_underlying_provider(self):
providerA = FakeExternalProvider("A", {'A': 'a:urn'}, ['b:urn'])
providerB = FakeExternalProvider("B", {'B': 'b:urn', 'C': 'c:urn'}, [])
providerR = yaml_provider.RenamingProvider( # keep wrapping
{'RenamedB': 'B', 'RenamedC': 'C' },
{'RenamedB': {}, 'RenamedC': {}},
providerB)

newR = providerR.with_underlying_provider(providerA)
self.assertIn('RenamedB', list(newR.provided_transforms()))
self.assertNotIn('RenamedC', list(newR.provided_transforms()))
t = newR.create_transform('RenamedB', {}, None)
self.assertEqual('A', t.creator)
self.assertEqual('b:urn', t.urn)

def test_extended_providers_reused(self):
providerA = FakeExternalProvider("A", {'A': 'a:urn'}, ['b:urn'])
providerB = FakeExternalProvider(
"B", {
'B': 'b:urn', 'C': 'c:urn'
}, [], error_on_use=True)
providerR = yaml_provider.RenamingProvider( # keep wrapping
{'RenamedB': 'B', 'RenamedC': 'C' },
{'RenamedB': {}, 'RenamedC': {}},
providerB)

with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
result = p | 'Yaml1' >> yaml_transform.YamlTransform(
'''
type: chain
transforms:
- type: Create
config:
elements: [0]
- type: A
- type: B
- type: RenamedB
''',
providers=[providerA, providerB, providerR])
# All of these transforms should be serviced by providerA,
# negating the need to invoke providerB.
assert_that(
result,
equal_to([(0, ('a:urn', 'A'), ('b:urn', 'A'), ('b:urn', 'A'))]))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
50 changes: 45 additions & 5 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import List
from typing import Mapping
from typing import Set
from typing import Tuple

import yaml
from yaml.loader import SafeLoader
Expand Down Expand Up @@ -178,6 +179,46 @@ def get_line(obj):
return getattr(obj, '_line_', 'unknown')


class ProviderSet:
def __init__(self, providers: Mapping[str, Iterable[yaml_provider.Provider]]):
self._providers = providers
self._original_providers = set(p for ps in providers.values() for p in ps)
self._expanded: Set[yaml_provider.Provider] = set()

def __contains__(self, key: str):
return key in self._providers

def __getitem__(self, key: str) -> Iterable[yaml_provider.Provider]:
return self._providers[key]

def items(self) -> Iterable[Tuple[str, Iterable[yaml_provider.Provider]]]:
return self._providers.items()

def use(self, provider):
underlying_provider = provider.underlying_provider()
if provider not in self._original_providers:
return
if underlying_provider in self._expanded:
return
self._expanded.add(underlying_provider)

# For every original provider, see if it can also be implemented in terms
# of this provider.
# This is useful because providers often vend everything that's been linked
# in, and we'd like to minimize provider boundary crossings.
# We do this lazily rather than when constructing the provider set to avoid
# instantiating every possible provider.
for other_provider in self._original_providers:
new_provider = other_provider.with_underlying_provider(
underlying_provider)
if new_provider is not None:
for t in new_provider.provided_transforms():
existing_providers = set(
p.underlying_provider() for p in self._providers[t])
if new_provider.underlying_provider() not in existing_providers:
self._providers[t].append(new_provider)


class LightweightScope(object):
def __init__(self, transforms):
self._transforms = transforms
Expand Down Expand Up @@ -220,7 +261,7 @@ def __init__(
root,
inputs: Mapping[str, Any],
transforms: Iterable[dict],
providers: Mapping[str, Iterable[yaml_provider.Provider]],
providers: ProviderSet,
input_providers: Iterable[yaml_provider.Provider]):
super().__init__(transforms)
self.root = root
Expand Down Expand Up @@ -371,6 +412,7 @@ def create_ptransform(self, spec, input_pcolls):
if pcoll in providers_by_input
]
provider = self.best_provider(spec, input_providers)
self.providers.use(provider)

config = SafeLineLoader.strip_metadata(spec.get('config', {}))
if not isinstance(config, dict):
Expand Down Expand Up @@ -509,9 +551,7 @@ def expand_composite_transform(spec, scope):
for (key, value) in empty_if_explicitly_empty(spec['input']).items()
},
spec['transforms'],
yaml_provider.merge_providers(
yaml_provider.parse_providers(spec.get('providers', [])),
scope.providers),
scope.providers,
scope.input_providers)

class CompositePTransform(beam.PTransform):
Expand Down Expand Up @@ -1012,7 +1052,7 @@ def expand(self, pcolls):
root,
pcolls,
transforms=[self._spec],
providers=self._providers,
providers=ProviderSet(self._providers),
input_providers={
pcoll: python_provider
for pcoll in pcolls.values()
Expand Down
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_scope_by_spec(self, p, spec):
scope = Scope(
beam.pvalue.PBegin(p), {},
spec['transforms'],
yaml_provider.standard_providers(), {})
yaml_transform.ProviderSet(yaml_provider.standard_providers()), {})
return scope, spec

def test_get_pcollection_input(self):
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_transform_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from apache_beam.yaml import YamlTransform
from apache_beam.yaml import yaml_provider
from apache_beam.yaml.yaml_provider import InlineProvider
from apache_beam.yaml.yaml_transform import ProviderSet
from apache_beam.yaml.yaml_transform import SafeLineLoader
from apache_beam.yaml.yaml_transform import Scope
from apache_beam.yaml.yaml_transform import chain_as_composite
Expand Down Expand Up @@ -119,7 +120,7 @@ def get_scope_by_spec(self, p, spec, inputs=None):
beam.pvalue.PBegin(p),
inputs,
spec['transforms'],
yaml_provider.standard_providers(), {})
ProviderSet(yaml_provider.standard_providers()), {})
return scope, spec

def test_pipeline_as_composite_with_type_transforms(self):
Expand Down

0 comments on commit dc5841d

Please sign in to comment.