Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[YAML] Increase re-use of providers with implicitly overlapping transforms. #30793

Merged
merged 5 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ def config_schema(self, type):
def description(self, type):
  """Returns the description for the given transform type.

  This implementation has nothing to report and always returns None,
  regardless of ``type``.
  """
  return None

def __repr__(self):
  """Returns ``ClassName[t1,t2,...]`` with the provided transforms sorted."""
  transform_names = ','.join(sorted(self.provided_transforms()))
  class_name = self.__class__.__name__
  return f'{class_name}[{transform_names}]'

def to_json(self):
  """Returns a dict whose only entry, 'type', is this object's class name."""
  return dict(type=self.__class__.__name__)

def requires_inputs(self, typ: str, args: Mapping[str, Any]) -> bool:
"""Returns whether this transform requires inputs.

Expand Down Expand Up @@ -105,6 +112,16 @@ def underlying_provider(self):
"""
return self

def with_underlying_provider(self, provider):
  """Returns a variant of this provider backed by the given provider.

  The returned provider should vend as many of this provider's transforms
  as possible while delegating to ``provider`` as its underlying provider.

  Many providers implicitly or explicitly link in a large set of common
  transforms, so re-basing them onto a provider that is already in use
  helps keep the number of distinct environments small.

  This default implementation cannot be re-based and returns None.
  """
  return None

def affinity(self, other: "Provider"):
"""Returns a value approximating how good it would be for this provider
to be used immediately following a transform from the other provider
Expand Down Expand Up @@ -152,6 +169,17 @@ def __init__(self, urns, service):
def provided_transforms(self):
  """Returns the keys of ``self._urns``, i.e. the transform names vended."""
  return self._urns.keys()

def with_underlying_provider(self, other_provider):
if not isinstance(other_provider, ExternalProvider):
return

all_urns = set(other_provider.schema_transforms().keys()).union(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is part of the stack trace. I'm very confident this PR is the culprit; I'm creating a revert for now since I'm not sure what the immediate solution is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... the only thing red was coverage in the presubmits.

Do you have a rollback already or should I create one?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#31169 - still waiting on presubmits to go green

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> Hmm... the only thing red was coverage in the presubmits.

Yeah, I noticed this - I'm not sure why... It looks like the failures are around 2.57.0.dev, so it could be some versioning issue?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked that we've published snapshots and everything though

set(other_provider._urns.values()))
for name, urn in self._urns.items():
if urn in all_urns and name not in other_provider._urns:
other_provider._urns[name] = urn
return other_provider

def schema_transforms(self):
if callable(self._service):
self._service = self._service()
Expand Down Expand Up @@ -518,6 +546,15 @@ def cache_artifacts(self):
def underlying_provider(self):
  """Returns whatever ``self.sql_provider()`` yields."""
  return self.sql_provider()

def with_underlying_provider(self, other_provider):
  """Returns a copy of this provider backed by ``other_provider`` if possible.

  The current SQL provider is first asked to re-base itself onto
  ``other_provider``; if it declines (returns None), ``other_provider`` is
  tried directly. A new SqlBackedProvider is only returned when the
  resulting candidate is truthy and vends a transform named 'Sql';
  otherwise the result is None.
  """
  candidate = self.sql_provider().with_underlying_provider(other_provider)
  if candidate is None:
    candidate = other_provider
  if not candidate:
    return None
  if 'Sql' not in set(candidate.provided_transforms()):
    return None
  return SqlBackedProvider(self._transforms, candidate)

def to_json(self):
  """Returns the fixed JSON form ``{'type': 'SqlBackedProvider'}``."""
  return dict(type="SqlBackedProvider")

Expand Down Expand Up @@ -995,7 +1032,12 @@ def __exit__(self, *args):

@ExternalProvider.register_provider_type('renaming')
class RenamingProvider(Provider):
def __init__(self, transforms, mappings, underlying_provider, defaults=None):
def __init__(
self,
transforms: Mapping[str, str],
mappings: Mapping[str, Mapping[str, str]],
underlying_provider: Provider,
defaults: Optional[Mapping[str, Mapping[str, Any]]] = None):
if isinstance(underlying_provider, dict):
underlying_provider = ExternalProvider.provider_from_spec(
underlying_provider)
Expand Down Expand Up @@ -1097,6 +1139,23 @@ def _affinity(self, other):
def underlying_provider(self):
  """Returns the underlying provider of the provider this one wraps."""
  wrapped = self._underlying_provider
  return wrapped.underlying_provider()

def with_underlying_provider(self, other_provider):
  """Returns a RenamingProvider re-based onto ``other_provider``, or None.

  The wrapped provider is asked to re-base itself first. Only aliases whose
  actual transform is still vended by the re-based provider are kept; if the
  wrapped provider cannot be re-based, or no alias survives, the result is
  None.
  """
  rebased = self._underlying_provider.with_underlying_provider(other_provider)
  if not rebased:
    return None

  still_vended = set(rebased.provided_transforms())
  kept_transforms = {}
  for alias, actual in self._transforms.items():
    if actual in still_vended:
      kept_transforms[alias] = actual
  if not kept_transforms:
    return None

  return RenamingProvider(
      kept_transforms, self._mappings, rebased, self._defaults)

def cache_artifacts(self):
  """Delegates artifact caching to the wrapped provider; returns None."""
  wrapped = self._underlying_provider
  wrapped.cache_artifacts()

Expand Down
120 changes: 120 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_provider_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import unittest

import apache_beam as beam
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform


class FakeTransform(beam.PTransform):
  """Test transform that appends a ``(urn, creator)`` pair to each element.

  Elements are expected to be tuples; the pair records which URN was
  requested and which fake provider created the transform.
  """
  def __init__(self, creator, urn):
    self.creator = creator
    self.urn = urn

  def expand(self, pcoll):
    tag = (self.urn, self.creator)
    return pcoll | beam.Map(lambda element: element + (tag, ))


class FakeExternalProvider(yaml_provider.ExternalProvider):
  """ExternalProvider stand-in that never talks to an expansion service.

  ``known_transforms`` maps transform names to URNs and is handed to the
  base class; ``extra_transform_urns`` are stored in ``_schema_transforms``
  as additional URNs, each mapped to None. With ``error_on_use`` set, any
  attempt to create a transform raises.
  """
  def __init__(
      self, id, known_transforms, extra_transform_urns, error_on_use=False):
    super().__init__(known_transforms, None)
    self._id = id
    self._schema_transforms = dict.fromkeys(extra_transform_urns)
    self._error_on_use = error_on_use

  def create_transform(self, type, *unused_args, **unused_kwargs):
    """Returns a FakeTransform tagged with this provider's id, or raises."""
    if not self._error_on_use:
      return FakeTransform(self._id, self._urns[type])
    raise RuntimeError(f'Provider {self._id} should not be used.')

  def available(self):
    """Always reports True."""
    # Claim we're available even if we error on use.
    return True


class YamlProvidersTest(unittest.TestCase):
  """Tests for re-basing providers onto a shared underlying provider."""
  def test_external_with_underlying_provider(self):
    """Re-basing B onto A makes A serve B's transforms and aliases."""
    providerA = FakeExternalProvider("A", {'A': 'a:urn'}, ['b:urn'])
    providerB = FakeExternalProvider("B", {'B': 'b:urn', 'alias': 'a:urn'}, [])
    newA = providerB.with_underlying_provider(providerA)

    self.assertIn('B', newA.provided_transforms())
    t = newA.create_transform('B')
    self.assertEqual('A', t.creator)
    self.assertEqual('b:urn', t.urn)

    self.assertIn('alias', newA.provided_transforms())
    t = newA.create_transform('alias')
    self.assertEqual('A', t.creator)
    self.assertEqual('a:urn', t.urn)

  def test_renaming_with_underlying_provider(self):
    """Only aliases whose target A can serve survive the re-basing."""
    providerA = FakeExternalProvider("A", {'A': 'a:urn'}, ['b:urn'])
    providerB = FakeExternalProvider("B", {'B': 'b:urn', 'C': 'c:urn'}, [])
    providerR = yaml_provider.RenamingProvider(  # keep wrapping
        {'RenamedB': 'B', 'RenamedC': 'C'},
        {'RenamedB': {}, 'RenamedC': {}},
        providerB)

    newR = providerR.with_underlying_provider(providerA)
    self.assertIn('RenamedB', newR.provided_transforms())
    self.assertNotIn('RenamedC', newR.provided_transforms())
    t = newR.create_transform('RenamedB', {}, None)
    self.assertEqual('A', t.creator)
    self.assertEqual('b:urn', t.urn)

  def test_extended_providers_reused(self):
    """A full pipeline is served by providerA alone once A is in use."""
    providerA = FakeExternalProvider("A", {'A': 'a:urn'}, ['b:urn'])
    # providerB raises if it is ever asked to create a transform.
    providerB = FakeExternalProvider(
        "B", {
            'B': 'b:urn', 'C': 'c:urn'
        }, [], error_on_use=True)
    providerR = yaml_provider.RenamingProvider(  # keep wrapping
        {'RenamedB': 'B', 'RenamedC': 'C'},
        {'RenamedB': {}, 'RenamedC': {}},
        providerB)

    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')) as p:
      # The nesting below is significant: `config` belongs to the Create
      # entry and `elements` belongs to `config`.
      result = p | 'Yaml1' >> yaml_transform.YamlTransform(
          '''
          type: chain
          transforms:
            - type: Create
              config:
                elements: [0]
            - type: A
            - type: B
            - type: RenamedB
          ''',
          providers=[providerA, providerB, providerR])
      # All of these transforms should be serviced by providerA,
      # negating the need to invoke providerB.
      assert_that(
          result,
          equal_to([(0, ('a:urn', 'A'), ('b:urn', 'A'), ('b:urn', 'A'))]))


if __name__ == '__main__':
  # Show INFO-level logs while the tests run as a script.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()
50 changes: 45 additions & 5 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import List
from typing import Mapping
from typing import Set
from typing import Tuple

import yaml
from yaml.loader import SafeLoader
Expand Down Expand Up @@ -178,6 +179,46 @@ def get_line(obj):
return getattr(obj, '_line_', 'unknown')


class ProviderSet:
  """Maps transform types to candidate providers, growing as providers are used.

  The set starts out with the given mapping. Each time a provider is actually
  used (see ``use``), the other original providers are asked whether they can
  also be served by that provider's underlying provider, and any resulting
  providers are added as further candidates.

  The mapping passed in is stored by reference and is mutated by ``use``.
  """
  def __init__(self, providers: Mapping[str, Iterable[yaml_provider.Provider]]):
    self._providers = providers
    # Snapshot of the providers supplied up front; only these trigger, and
    # take part in, expansion.
    self._original_providers = {p for ps in providers.values() for p in ps}
    # Underlying providers that have already been expanded.
    self._expanded: Set[yaml_provider.Provider] = set()

  def __contains__(self, key: str):
    """Returns whether any provider is registered for transform ``key``."""
    return key in self._providers

  def __getitem__(self, key: str) -> Iterable[yaml_provider.Provider]:
    """Returns the candidate providers registered for transform ``key``."""
    return self._providers[key]

  def items(self) -> Iterable[Tuple[str, Iterable[yaml_provider.Provider]]]:
    """Returns the (transform type, providers) pairs of the mapping."""
    return self._providers.items()

  def use(self, provider):
    """Notes that ``provider`` is in use and expands the set accordingly.

    Does nothing if ``provider`` is not one of the original providers or if
    its underlying provider has already been expanded.
    """
    # Check membership first so that the underlying provider of an unrelated
    # provider is never resolved needlessly.
    if provider not in self._original_providers:
      return
    underlying_provider = provider.underlying_provider()
    if underlying_provider in self._expanded:
      return
    self._expanded.add(underlying_provider)

    # For every original provider, see if it can also be implemented in terms
    # of this provider.
    # This is useful because providers often vend everything that's been linked
    # in, and we'd like to minimize provider boundary crossings.
    # We do this lazily rather than when constructing the provider set to avoid
    # instantiating every possible provider.
    for other_provider in self._original_providers:
      new_provider = other_provider.with_underlying_provider(
          underlying_provider)
      if new_provider is None:
        continue
      for t in new_provider.provided_transforms():
        # The derived provider may vend a transform nobody registered yet;
        # start an empty candidate list rather than failing on the lookup.
        candidates = self._providers.setdefault(t, [])
        existing_providers = {p.underlying_provider() for p in candidates}
        if new_provider.underlying_provider() not in existing_providers:
          candidates.append(new_provider)


class LightweightScope(object):
def __init__(self, transforms):
self._transforms = transforms
Expand Down Expand Up @@ -220,7 +261,7 @@ def __init__(
root,
inputs: Mapping[str, Any],
transforms: Iterable[dict],
providers: Mapping[str, Iterable[yaml_provider.Provider]],
providers: ProviderSet,
input_providers: Iterable[yaml_provider.Provider]):
super().__init__(transforms)
self.root = root
Expand Down Expand Up @@ -371,6 +412,7 @@ def create_ptransform(self, spec, input_pcolls):
if pcoll in providers_by_input
]
provider = self.best_provider(spec, input_providers)
self.providers.use(provider)

config = SafeLineLoader.strip_metadata(spec.get('config', {}))
if not isinstance(config, dict):
Expand Down Expand Up @@ -509,9 +551,7 @@ def expand_composite_transform(spec, scope):
for (key, value) in empty_if_explicitly_empty(spec['input']).items()
},
spec['transforms'],
yaml_provider.merge_providers(
yaml_provider.parse_providers(spec.get('providers', [])),
scope.providers),
scope.providers,
scope.input_providers)

class CompositePTransform(beam.PTransform):
Expand Down Expand Up @@ -1005,7 +1045,7 @@ def expand(self, pcolls):
root,
pcolls,
transforms=[self._spec],
providers=self._providers,
providers=ProviderSet(self._providers),
input_providers={
pcoll: python_provider
for pcoll in pcolls.values()
Expand Down
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_scope_by_spec(self, p, spec):
scope = Scope(
beam.pvalue.PBegin(p), {},
spec['transforms'],
yaml_provider.standard_providers(), {})
yaml_transform.ProviderSet(yaml_provider.standard_providers()), {})
return scope, spec

def test_get_pcollection_input(self):
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_transform_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from apache_beam.yaml import YamlTransform
from apache_beam.yaml import yaml_provider
from apache_beam.yaml.yaml_provider import InlineProvider
from apache_beam.yaml.yaml_transform import ProviderSet
from apache_beam.yaml.yaml_transform import SafeLineLoader
from apache_beam.yaml.yaml_transform import Scope
from apache_beam.yaml.yaml_transform import chain_as_composite
Expand Down Expand Up @@ -119,7 +120,7 @@ def get_scope_by_spec(self, p, spec, inputs=None):
beam.pvalue.PBegin(p),
inputs,
spec['transforms'],
yaml_provider.standard_providers(), {})
ProviderSet(yaml_provider.standard_providers()), {})
return scope, spec

def test_pipeline_as_composite_with_type_transforms(self):
Expand Down
Loading