Remove special PyMap, PyFlatMap, etc. (#28546)
Prefer to use generic MapToFields instead.
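To make the migration concrete, here is a minimal sketch (not taken from the diff below; field names and the expression are illustrative) of how a step written against the removed `PyMap` transform can be rewritten with the generic `MapToFields` transform:

```
# Before (removed by this commit): a raw Python lambda via PyMap.
# - type: PyMap
#   config:
#     fn: "lambda x: x.col1 * 2"

# After: generic MapToFields with an explicit language and named output fields.
- type: MapToFields
  config:
    language: python
    fields:
      doubled: "col1 * 2"
```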
robertwb authored Sep 25, 2023
1 parent 362a887 commit dd9d8d0
Showing 7 changed files with 147 additions and 172 deletions.
44 changes: 26 additions & 18 deletions sdks/python/apache_beam/yaml/README.md
@@ -72,14 +72,15 @@ pipeline:
- type: ReadFromCsv
config:
path: /path/to/input*.csv
- type: PyFilter
- type: Filter
config:
keep: "lambda x: x.col3 > 100"
language: python
keep: "col3 > 100"
input: ReadFromCsv
- type: WriteToJson
config:
path: /path/to/output.json
input: PyFilter
input: Filter
```

or two.
@@ -90,15 +91,16 @@ pipeline:
- type: ReadFromCsv
config:
path: /path/to/input*.csv
- type: PyFilter
- type: Filter
config:
keep: "lambda x: x.col3 > 100"
language: python
keep: "col3 > 100"
input: ReadFromCsv
- type: Sql
name: MySqlTransform
config:
query: "select col1, count(*) as cnt from PCOLLECTION group by col1"
input: PyFilter
input: Filter
- type: WriteToJson
config:
path: /path/to/output.json
@@ -116,9 +118,10 @@ pipeline:
- type: ReadFromCsv
config:
path: /path/to/input*.csv
- type: PyFilter
- type: Filter
config:
keep: "lambda x: x.col3 > 100"
language: python
keep: "col3 > 100"
- type: Sql
name: MySqlTransform
config:
@@ -141,9 +144,10 @@ pipeline:
path: /path/to/input*.csv
transforms:
- type: PyFilter
- type: Filter
config:
keep: "lambda x: x.col3 > 100"
language: python
keep: "col3 > 100"
- type: Sql
name: MySqlTransform
@@ -185,11 +189,12 @@ pipeline:
config:
path: /path/to/all.json
- type: PyFilter
- type: Filter
name: FilterToBig
input: Sql
config:
keep: "lambda x: x.col2 > 100"
language: python
keep: "col2 > 100"
- type: WriteToCsv
name: WriteBig
@@ -231,15 +236,18 @@ pipeline:
name: ExtraProcessingForBigRows
input: Sql
transforms:
- type: PyFilter
- type: Filter
config:
keep: "lambda x: x.col2 > 100"
- type: PyFilter
language: python
keep: "col2 > 100"
- type: Filter
config:
keep: "lambda x: len(x.col1) > 10"
- type: PyFilter
language: python
keep: "len(col1) > 10"
- type: Filter
config:
keep: "lambda x: x.col1 > 'z'"
language: python
keep: "col1 > 'z'"
sink:
type: WriteToCsv
config:
9 changes: 7 additions & 2 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -257,8 +257,13 @@ def with_exception_handling(self, **kwargs):
@maybe_with_exception_handling_transform_fn
def _PyJsFilter(
pcoll, keep: Union[str, Dict[str, str]], language: Optional[str] = None):
try:
input_schema = dict(named_fields_from_element_type(pcoll.element_type))
except (TypeError, ValueError) as exn:
if is_expr(keep):
raise ValueError("Can only use expressions on a schema'd input.") from exn
input_schema = {} # unused

input_schema = dict(named_fields_from_element_type(pcoll.element_type))
if isinstance(keep, str) and keep in input_schema:
keep_fn = lambda row: getattr(row, keep)
else:
@@ -273,7 +278,7 @@ def is_expr(v):
def normalize_fields(pcoll, fields, drop=(), append=False, language='generic'):
try:
input_schema = dict(named_fields_from_element_type(pcoll.element_type))
except ValueError as exn:
except (TypeError, ValueError) as exn:
if drop:
raise ValueError("Can only drop fields on a schema'd input.") from exn
if append:
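The new `try`/`except` in `_PyJsFilter` above means an expression-valued `keep` is only accepted when the input carries a schema, while a callable-valued `keep` still works without one. A hedged sketch of the two forms (field names are illustrative, and the dict-valued `callable` spelling is my assumption about the config shape):

```
# Expression form: requires a schema'd input, since "col3" is a named field.
- type: Filter
  config:
    language: python
    keep: "col3 > 100"

# Callable form: usable even when the input PCollection carries no schema.
- type: Filter
  config:
    language: python
    keep:
      callable: "lambda x: x > 100"
```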
16 changes: 6 additions & 10 deletions sdks/python/apache_beam/yaml/yaml_provider.py
@@ -23,6 +23,7 @@
import hashlib
import inspect
import json
import logging
import os
import subprocess
import sys
@@ -580,18 +581,13 @@ def _parse_window_spec(spec):
# TODO: Triggering, etc.
return beam.WindowInto(window_fn)

def log_and_return(x):
logging.info(x)
return x

return InlineProvider({
'Create': create,
'PyMap': lambda fn: beam.Map(
python_callable.PythonCallableWithSource(fn)),
'PyMapTuple': lambda fn: beam.MapTuple(
python_callable.PythonCallableWithSource(fn)),
'PyFlatMap': lambda fn: beam.FlatMap(
python_callable.PythonCallableWithSource(fn)),
'PyFlatMapTuple': lambda fn: beam.FlatMapTuple(
python_callable.PythonCallableWithSource(fn)),
'PyFilter': lambda keep: beam.Filter(
python_callable.PythonCallableWithSource(keep)),
'LogForTesting': lambda: beam.Map(log_and_return),
'PyTransform': fully_qualified_named_transform,
'WithSchemaExperimental': with_schema,
'Flatten': Flatten,
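Of the inline providers touched above, only `LogForTesting` is new; it simply logs each element and passes it through unchanged. A minimal usage sketch (the `Create` step and element values are illustrative):

```
- type: Create
  config:
    elements: [1, 2, 3]
- type: LogForTesting
  input: Create
```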
14 changes: 8 additions & 6 deletions sdks/python/apache_beam/yaml/yaml_transform.py
@@ -876,7 +876,8 @@ def ensure_transforms_have_providers(spec):
if known_transforms:
if spec['type'] not in known_transforms:
raise ValueError(
f'Unknown type or missing provider for {identify_object(spec)}')
'Unknown type or missing provider '
f'for type {spec["type"]} for {identify_object(spec)}')
return spec

def preprocess_langauges(spec):
@@ -919,13 +920,14 @@ class YamlTransform(beam.PTransform):
def __init__(self, spec, providers={}): # pylint: disable=dangerous-default-value
if isinstance(spec, str):
spec = yaml.load(spec, Loader=SafeLineLoader)
if isinstance(providers, dict):
providers = {
key: yaml_provider.as_provider_list(key, value)
for (key, value) in providers.items()
}
# TODO(BEAM-26941): Validate as a transform.
self._providers = yaml_provider.merge_providers(
{
key: yaml_provider.as_provider_list(key, value)
for (key, value) in providers.items()
},
yaml_provider.standard_providers())
providers, yaml_provider.standard_providers())
self._spec = preprocess(spec, known_transforms=self._providers.keys())

def expand(self, pcolls):
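With the `Py*` entries gone from the standard providers, a pipeline that still names one of them is rejected during preprocessing, and the reworked message above now includes the offending type. A rough illustration (the exact rendered message, including the `identify_object` portion, will differ):

```
# Fails validation with something like:
#   ValueError: Unknown type or missing provider for type PyMap for ...
- type: PyMap
  config:
    fn: "lambda x: x * 2"
```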
37 changes: 18 additions & 19 deletions sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
@@ -60,11 +60,9 @@ def test_get_pcollection_output(self):
- type: Create
config:
elements: [0, 1, 3, 4]
- type: PyMap
- type: LogForTesting
name: Square
input: Create
config:
fn: "lambda x: x*x"
'''

scope, spec = self.get_scope_by_spec(p, spec)
@@ -77,38 +75,39 @@
"PCollection[Square.None]", str(scope.get_pcollection("Square")))

self.assertEqual(
"PCollection[Square.None]", str(scope.get_pcollection("PyMap")))
"PCollection[Square.None]", str(scope.get_pcollection("LogForTesting")))

self.assertTrue(
scope.get_pcollection("Square") == scope.get_pcollection("PyMap"))
scope.get_pcollection("Square") == scope.get_pcollection(
"LogForTesting"))

def test_create_ptransform(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
spec = '''
transforms:
- type: PyMap
input: something
- type: Create
config:
fn: "lambda x: x*x"
elements: [1, 2, 3]
'''
scope, spec = self.get_scope_by_spec(p, spec)

result = scope.create_ptransform(spec['transforms'][0], ['something'])
self.assertIsInstance(result, beam.transforms.ParDo)
self.assertEqual(result.label, 'Map(lambda x: x*x)')
result = scope.create_ptransform(spec['transforms'][0], [])
self.assertIsInstance(result, beam.transforms.Create)
self.assertEqual(result.label, 'Create')

result_annotations = {
key: value
for (key, value) in result.annotations().items()
if key.startswith('yaml')
}
result_annotations = {**result.annotations()}
target_annotations = {
'yaml_type': 'PyMap',
'yaml_args': '{"fn": "lambda x: x*x"}',
'yaml_type': 'Create',
'yaml_args': '{"elements": [1, 2, 3]}',
'yaml_provider': '{"type": "InlineProvider"}'
}
self.assertDictEqual(result_annotations, target_annotations)

# Check if target_annotations is a subset of result_annotations
self.assertDictEqual(
result_annotations, {
**result_annotations, **target_annotations
})


class TestProvider(yaml_provider.InlineProvider):