Skip to content

Commit

Permalink
Merge pull request #29093 [YAML] Don't require redundant input for Ya…
Browse files Browse the repository at this point in the history
…mlTransform.
  • Loading branch information
robertwb authored Oct 20, 2023
1 parent 97a52af commit abce1ad
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 22 deletions.
6 changes: 0 additions & 6 deletions sdks/python/apache_beam/yaml/yaml_io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ def test_simple_write(self):
| YamlTransform(
'''
type: WriteToPubSub
input: input
config:
topic: my_topic
format: raw
Expand All @@ -341,7 +340,6 @@ def test_write_with_attribute(self):
]) | YamlTransform(
'''
type: WriteToPubSub
input: input
config:
topic: my_topic
format: raw
Expand All @@ -364,7 +362,6 @@ def test_write_with_attribute_map(self):
]) | YamlTransform(
'''
type: WriteToPubSub
input: input
config:
topic: my_topic
format: raw
Expand All @@ -384,7 +381,6 @@ def test_write_with_id_attribute(self):
| YamlTransform(
'''
type: WriteToPubSub
input: input
config:
topic: my_topic
format: raw
Expand All @@ -408,7 +404,6 @@ def test_write_avro(self):
| YamlTransform(
'''
type: WriteToPubSub
input: input
config:
topic: my_topic
format: avro
Expand All @@ -434,7 +429,6 @@ def test_write_json(self):
]) | YamlTransform(
'''
type: WriteToPubSub
input: input
config:
topic: my_topic
format: json
Expand Down
4 changes: 0 additions & 4 deletions sdks/python/apache_beam/yaml/yaml_mapping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def test_basic(self):
result = elements | YamlTransform(
'''
type: MapToFields
input: input
config:
language: python
fields:
Expand All @@ -62,7 +61,6 @@ def test_drop(self):
result = elements | YamlTransform(
'''
type: MapToFields
input: input
config:
fields: {}
append: true
Expand All @@ -83,7 +81,6 @@ def test_filter(self):
result = elements | YamlTransform(
'''
type: Filter
input: input
config:
language: python
keep: "rank > 0"
Expand All @@ -106,7 +103,6 @@ def test_explode(self):
result = elements | YamlTransform(
'''
type: chain
input: input
transforms:
- type: MapToFields
config:
Expand Down
12 changes: 11 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def is_not_output_of_last_transform(new_transforms, value):
raise TypeError(
f"Chain at {identify_object(spec)} missing transforms property.")
has_explicit_outputs = 'output' in spec
composite_spec = normalize_inputs_outputs(spec)
composite_spec = normalize_inputs_outputs(tag_explicit_inputs(spec))
new_transforms = []
for ix, transform in enumerate(composite_spec['transforms']):
if any(io in transform for io in ('input', 'output')):
Expand All @@ -539,6 +539,8 @@ def is_not_output_of_last_transform(new_transforms, value):
pass
elif is_explicitly_empty(composite_spec['input']):
transform['input'] = composite_spec['input']
elif is_empty(composite_spec['input']):
del composite_spec['input']
else:
transform['input'] = {
key: key
Expand Down Expand Up @@ -931,6 +933,7 @@ def __init__(self, spec, providers={}): # pylint: disable=dangerous-default-val
self._providers = yaml_provider.merge_providers(
providers, yaml_provider.standard_providers())
self._spec = preprocess(spec, known_transforms=self._providers.keys())
self._was_chain = spec['type'] == 'chain'

def expand(self, pcolls):
if isinstance(pcolls, beam.pvalue.PBegin):
Expand All @@ -939,8 +942,15 @@ def expand(self, pcolls):
elif isinstance(pcolls, beam.PCollection):
root = pcolls.pipeline
pcolls = {'input': pcolls}
if not self._spec['input']:
self._spec['input'] = {'input': 'input'}
if self._was_chain and self._spec['transforms']:
# This should have been copied as part of the composite-to-chain.
self._spec['transforms'][0]['input'] = self._spec['input']
else:
root = next(iter(pcolls.values())).pipeline
if not self._spec['input']:
self._spec['input'] = {name: name for name in pcolls.keys()}
result = expand_transform(
self._spec,
Scope(
Expand Down
12 changes: 11 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ def expand(self, p):
| beam.Map(lambda x: beam.transforms.window.TimestampedValue(x, x)))


class CreateInts(beam.PTransform):
_yaml_requires_inputs = False

def __init__(self, elements):
self._elements = elements

def expand(self, p):
return p | beam.Create(self._elements)


class SumGlobally(beam.PTransform):
def expand(self, pcoll):
return pcoll | beam.CombineGlobally(sum).without_defaults()
Expand All @@ -65,7 +75,7 @@ def raise_on_big(element):


TEST_PROVIDERS = {
'CreateInts': lambda elements: beam.Create(elements),
'CreateInts': CreateInts,
'CreateTimestamped': CreateTimestamped,
'SumGlobally': SumGlobally,
'SizeLimiter': SizeLimiter,
Expand Down
2 changes: 0 additions & 2 deletions sdks/python/apache_beam/yaml/yaml_transform_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,10 @@ def test_chain_as_composite(self):
expected = f'''
type: composite
name: Chain
input: {{}}
transforms:
- type: Create
config:
elements: [0,1,2]
input: {{}}
- type: PyMap
config:
fn: 'lambda x: x*x'
Expand Down
8 changes: 0 additions & 8 deletions sdks/python/apache_beam/yaml/yaml_udf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def test_map_to_fields_filter_inline_js(self):
result = elements | YamlTransform(
'''
type: MapToFields
input: input
config:
language: javascript
fields:
Expand All @@ -79,7 +78,6 @@ def test_map_to_fields_filter_inline_py(self):
result = elements | YamlTransform(
'''
type: MapToFields
input: input
config:
language: python
fields:
Expand All @@ -103,7 +101,6 @@ def test_filter_inline_js(self):
result = elements | YamlTransform(
'''
type: Filter
input: input
config:
language: javascript
keep:
Expand All @@ -123,7 +120,6 @@ def test_filter_inline_py(self):
result = elements | YamlTransform(
'''
type: Filter
input: input
config:
language: python
keep:
Expand All @@ -143,7 +139,6 @@ def test_filter_expression_js(self):
result = elements | YamlTransform(
'''
type: Filter
input: input
config:
language: javascript
keep:
Expand All @@ -162,7 +157,6 @@ def test_filter_expression_py(self):
result = elements | YamlTransform(
'''
type: Filter
input: input
config:
language: python
keep:
Expand Down Expand Up @@ -194,7 +188,6 @@ def test_filter_inline_js_file(self):
result = elements | YamlTransform(
f'''
type: Filter
input: input
config:
language: javascript
keep:
Expand Down Expand Up @@ -226,7 +219,6 @@ def g(x):
result = elements | YamlTransform(
f'''
type: Filter
input: input
config:
language: python
keep:
Expand Down

0 comments on commit abce1ad

Please sign in to comment.