Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better errors when inputs are omitted. #28289

Merged
merged 6 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 60 additions & 29 deletions sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ def provided_transforms(self) -> Iterable[str]:
"""Returns a list of transform type names this provider can handle."""
raise NotImplementedError(type(self))

def requires_inputs(self, typ: str, args: Mapping[str, Any]) -> bool:
  """Returns whether the given transform requires inputs.

  If this returns True and no inputs are provided, then an error will be
  raised.

  This is a best-effort check, primarily for better and earlier error
  messages. By default, any transform whose type name does not start with
  'Read' is assumed to require inputs.
  """
  if typ.startswith('Read'):
    return False
  return True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presuming this is to provide a reasonable default? Seems like it might lead to surprising false positives though?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, just a reasonable default. Most transforms that do not start with Read require an input, the main exception being SQL (where we can't tell, and the "typical" use of selecting from a PCollection would be good to error on). One can explicitly state `input: []` to suppress this error.


def create_transform(
self,
typ: str,
Expand Down Expand Up @@ -125,9 +135,7 @@ def __init__(self, urns, service):
def provided_transforms(self):
  """Returns the transform type names that are keys of self._urns."""
  urns = self._urns
  return urns.keys()

def create_transform(self, type, args, yaml_create_transform):
if callable(self._service):
self._service = self._service()
def schema_transforms(self):
if self._schema_transforms is None:
try:
self._schema_transforms = {
Expand All @@ -138,8 +146,19 @@ def create_transform(self, type, args, yaml_create_transform):
except Exception:
# It's possible this service doesn't vend schema transforms.
self._schema_transforms = {}
return self._schema_transforms

def requires_inputs(self, typ, args):
  """Returns whether the given external transform requires inputs.

  If the transform's URN is among the schema transforms returned by
  self.schema_transforms(), the answer is whether that schema transform
  declares any inputs. Otherwise this falls back to the base class default.

  NOTE: per the review discussion on this change, the inputs declared by a
  schema transform do not depend on the configuration passed to it, so they
  only reflect inputs present in all possible configurations. Most (though
  not all) transforms have fixed inputs, so this is a best-effort check;
  querying the requirements of a configured transform would be a larger
  change.

  Args:
    typ: the transform type name, used as a key into self._urns.
    args: the transform configuration; only forwarded to the base class.

  Returns:
    True if inputs are required, False otherwise.
  """
  # Index by the `typ` parameter; the builtin `type` is never a key here.
  urn = self._urns[typ]
  # Fetch the discovered schema transforms once rather than per lookup.
  schema_transforms = self.schema_transforms()
  if urn in schema_transforms:
    return bool(schema_transforms[urn].inputs)
  else:
    return super().requires_inputs(typ, args)

def create_transform(self, type, args, yaml_create_transform):
if callable(self._service):
self._service = self._service()
urn = self._urns[type]
if urn in self._schema_transforms:
if urn in self.schema_transforms():
return external.SchemaAwareExternalTransform(
urn, self._service, rearrange_based_on_discovery=True, **args)
else:
Expand Down Expand Up @@ -345,8 +364,9 @@ def fn_takes_side_inputs(fn):


class InlineProvider(Provider):
def __init__(self, transform_factories):
def __init__(self, transform_factories, no_input_transforms=()):
self._transform_factories = transform_factories
self._no_input_transforms = set(no_input_transforms)

def available(self):
  """Always reports this provider as available."""
  return True
Expand All @@ -360,6 +380,14 @@ def create_transform(self, type, args, yaml_create_transform):
def to_json(self):
  """Returns a JSON-compatible mapping naming this provider's type."""
  description = {'type': "InlineProvider"}
  return description

def requires_inputs(self, typ, args):
  """Returns whether the given inline transform requires inputs.

  The answer is resolved in order of precedence:
    1. Types listed in self._no_input_transforms never require inputs.
    2. A factory carrying a `_yaml_requires_inputs` attribute supplies the
       answer directly.
    3. Otherwise the base class default is used.
  """
  if typ in self._no_input_transforms:
    return False
  factory = self._transform_factories[typ]
  # A unique sentinel distinguishes "attribute absent" from any real value.
  unset = object()
  declared = getattr(factory, '_yaml_requires_inputs', unset)
  if declared is not unset:
    return declared
  return super().requires_inputs(typ, args)


class MetaInlineProvider(InlineProvider):
def create_transform(self, type, args, yaml_create_transform):
Expand Down Expand Up @@ -491,30 +519,30 @@ def _parse_window_spec(spec):
# TODO: Triggering, etc.
return beam.WindowInto(window_fn)

return InlineProvider(
dict({
'Create': create,
'PyMap': lambda fn: beam.Map(
python_callable.PythonCallableWithSource(fn)),
'PyMapTuple': lambda fn: beam.MapTuple(
python_callable.PythonCallableWithSource(fn)),
'PyFlatMap': lambda fn: beam.FlatMap(
python_callable.PythonCallableWithSource(fn)),
'PyFlatMapTuple': lambda fn: beam.FlatMapTuple(
python_callable.PythonCallableWithSource(fn)),
'PyFilter': lambda keep: beam.Filter(
python_callable.PythonCallableWithSource(keep)),
'PyTransform': fully_qualified_named_transform,
'PyToRow': lambda fields: beam.Select(
**{
name: python_callable.PythonCallableWithSource(fn)
for (name, fn) in fields.items()
}),
'WithSchema': with_schema,
'Flatten': Flatten,
'WindowInto': WindowInto,
'GroupByKey': beam.GroupByKey,
}))
return InlineProvider({
'Create': create,
'PyMap': lambda fn: beam.Map(
python_callable.PythonCallableWithSource(fn)),
'PyMapTuple': lambda fn: beam.MapTuple(
python_callable.PythonCallableWithSource(fn)),
'PyFlatMap': lambda fn: beam.FlatMap(
python_callable.PythonCallableWithSource(fn)),
'PyFlatMapTuple': lambda fn: beam.FlatMapTuple(
python_callable.PythonCallableWithSource(fn)),
'PyFilter': lambda keep: beam.Filter(
python_callable.PythonCallableWithSource(keep)),
'PyTransform': fully_qualified_named_transform,
'PyToRow': lambda fields: beam.Select(
**{
name: python_callable.PythonCallableWithSource(fn)
for (name, fn) in fields.items()
}),
'WithSchema': with_schema,
'Flatten': Flatten,
'WindowInto': WindowInto,
'GroupByKey': beam.GroupByKey,
},
no_input_transforms=('Create', ))


class PypiExpansionService:
Expand Down Expand Up @@ -585,6 +613,9 @@ def available(self) -> bool:
def provided_transforms(self) -> Iterable[str]:
  """Returns the transform type names that are keys of self._transforms."""
  transforms = self._transforms
  return transforms.keys()

def requires_inputs(self, typ, args):
  """Delegates the inputs-required check to the underlying provider."""
  underlying = self._underlying_provider
  return underlying.requires_inputs(typ, args)

def create_transform(
self,
typ: str,
Expand Down
84 changes: 67 additions & 17 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ def only_element(xs):
return x


# These allow a user to explicitly pass no input to a transform (i.e. use it
# as a root transform) without an error even if the transform is not known to
# handle it.
def explicitly_empty():
  """Returns a new sentinel mapping that marks an input as explicitly empty."""
  return dict(__explicitly_empty__=None)


def is_explicitly_empty(io):
  """Returns whether io equals the explicitly-empty sentinel mapping."""
  sentinel = explicitly_empty()
  return io == sentinel


def is_empty(io):
  """Returns whether io is falsy or is the explicitly-empty sentinel."""
  if not io:
    return True
  return is_explicitly_empty(io)


def empty_if_explicitly_empty(io):
  """Returns {} for the explicitly-empty sentinel, otherwise io unchanged."""
  return {} if is_explicitly_empty(io) else io


class SafeLineLoader(SafeLoader):
"""A yaml loader that attaches line information to mappings and strings."""
class TaggedString(str):
Expand Down Expand Up @@ -186,7 +208,7 @@ def followers(self, transform_name):
# TODO(yaml): Also trace through outputs and composites.
for transform in self._transforms:
if transform['type'] != 'composite':
for input in transform.get('input').values():
for input in empty_if_explicitly_empty(transform['input']).values():
transform_id, _ = self.get_transform_id_and_output_name(input)
self._all_followers[transform_id].append(transform['__uuid__'])
return self._all_followers[self.get_transform_id(transform_name)]
Expand Down Expand Up @@ -324,6 +346,12 @@ def create_ptransform(self, spec, input_pcolls):
raise ValueError(
'Config for transform at %s must be a mapping.' %
identify_object(spec))

if (not input_pcolls and not is_explicitly_empty(spec.get('input', {})) and
provider.requires_inputs(spec['type'], config)):
raise ValueError(
f'Missing inputs for transform at {identify_object(spec)}')

try:
# pylint: disable=undefined-loop-variable
ptransform = provider.create_transform(
Expand Down Expand Up @@ -402,7 +430,7 @@ def expand_leaf_transform(spec, scope):
spec = normalize_inputs_outputs(spec)
inputs_dict = {
key: scope.get_pcollection(value)
for (key, value) in spec['input'].items()
for (key, value) in empty_if_explicitly_empty(spec['input']).items()
}
input_type = spec.get('input_type', 'default')
if input_type == 'list':
Expand Down Expand Up @@ -442,10 +470,10 @@ def expand_composite_transform(spec, scope):
spec = normalize_inputs_outputs(normalize_source_sink(spec))

inner_scope = Scope(
scope.root, {
scope.root,
{
key: scope.get_pcollection(value)
for key,
value in spec['input'].items()
for (key, value) in empty_if_explicitly_empty(spec['input']).items()
},
spec['transforms'],
yaml_provider.merge_providers(
Expand All @@ -470,8 +498,7 @@ def expand(inputs):
_LOGGER.info("Expanding %s ", identify_object(spec))
return ({
key: scope.get_pcollection(value)
for key,
value in spec['input'].items()
for (key, value) in empty_if_explicitly_empty(spec['input']).items()
} or scope.root) | scope.unique_name(spec, None) >> CompositePTransform()


Expand All @@ -496,12 +523,25 @@ def is_not_output_of_last_transform(new_transforms, value):
composite_spec = normalize_inputs_outputs(spec)
new_transforms = []
for ix, transform in enumerate(composite_spec['transforms']):
if any(io in transform for io in ('input', 'output', 'input', 'output')):
raise ValueError(
f'Transform {identify_object(transform)} is part of a chain, '
'must have implicit inputs and outputs.')
if any(io in transform for io in ('input', 'output')):
if (ix == 0 and 'input' in transform and 'output' not in transform and
is_explicitly_empty(transform['input'])):
# This is OK as source clause sets an explicitly empty input.
pass
else:
raise ValueError(
f'Transform {identify_object(transform)} is part of a chain, '
'must have implicit inputs and outputs.')
if ix == 0:
transform['input'] = {key: key for key in composite_spec['input'].keys()}
if is_explicitly_empty(transform.get('input', None)):
pass
elif is_explicitly_empty(composite_spec['input']):
transform['input'] = composite_spec['input']
else:
transform['input'] = {
key: key
for key in composite_spec['input'].keys()
}
else:
transform['input'] = new_transforms[-1]['__uuid__']
new_transforms.append(transform)
Expand Down Expand Up @@ -554,6 +594,8 @@ def normalize_source_sink(spec):
spec = dict(spec)
spec['transforms'] = list(spec.get('transforms', []))
if 'source' in spec:
if 'input' not in spec['source']:
spec['source']['input'] = explicitly_empty()
spec['transforms'].insert(0, spec.pop('source'))
if 'sink' in spec:
spec['transforms'].append(spec.pop('sink'))
Expand All @@ -567,6 +609,13 @@ def preprocess_source_sink(spec):
return spec


def tag_explicit_inputs(spec):
  """Tags an input that is present but empty as explicitly empty.

  If spec has an 'input' entry that is falsy once SafeLineLoader metadata has
  been stripped, returns a copy of spec whose 'input' is the explicitly-empty
  sentinel. Otherwise spec is returned unchanged.
  """
  if 'input' not in spec:
    return spec
  if SafeLineLoader.strip_metadata(spec['input']):
    # A non-empty input was given; leave the spec as it is.
    return spec
  return dict(spec, input=explicitly_empty())


def normalize_inputs_outputs(spec):
spec = dict(spec)

Expand Down Expand Up @@ -611,7 +660,7 @@ def push_windowing_to_roots(spec):
scope = LightweightScope(spec['transforms'])
consumed_outputs_by_transform = collections.defaultdict(set)
for transform in spec['transforms']:
for _, input_ref in transform['input'].items():
for _, input_ref in empty_if_explicitly_empty(transform['input']).items():
try:
transform_id, output = scope.get_transform_id_and_output_name(input_ref)
consumed_outputs_by_transform[transform_id].add(output)
Expand All @@ -620,7 +669,7 @@ def push_windowing_to_roots(spec):
pass

for transform in spec['transforms']:
if not transform['input'] and 'windowing' not in transform:
if is_empty(transform['input']) and 'windowing' not in transform:
transform['windowing'] = spec['windowing']
transform['__consumed_outputs'] = consumed_outputs_by_transform[
transform['__uuid__']]
Expand All @@ -647,7 +696,7 @@ def preprocess_windowing(spec):
spec = push_windowing_to_roots(spec)

windowing = spec.pop('windowing')
if spec['input']:
if not is_empty(spec['input']):
# Apply the windowing to all inputs by wrapping it in a transform that
# first applies windowing and then applies the original transform.
original_inputs = spec['input']
Expand Down Expand Up @@ -778,7 +827,7 @@ def ensure_errors_consumed(spec):
raise ValueError(
f'Missing output in error_handling of {identify_object(t)}')
to_handle[t['__uuid__'], config['error_handling']['output']] = t
for _, input in t['input'].items():
for _, input in empty_if_explicitly_empty(t['input']).items():
if input not in spec['input']:
consumed.add(scope.get_transform_id_and_output_name(input))
for error_pcoll, t in to_handle.items():
Expand Down Expand Up @@ -815,7 +864,7 @@ def preprocess(spec, verbose=False, known_transforms=None):

def apply(phase, spec):
spec = phase(spec)
if spec['type'] in {'composite', 'chain'}:
if spec['type'] in {'composite', 'chain'} and 'transforms' in spec:
spec = dict(
spec, transforms=[apply(phase, t) for t in spec['transforms']])
return spec
Expand All @@ -835,6 +884,7 @@ def ensure_transforms_have_providers(spec):
ensure_transforms_have_providers,
preprocess_source_sink,
preprocess_chain,
tag_explicit_inputs,
normalize_inputs_outputs,
preprocess_flattened_inputs,
ensure_errors_consumed,
Expand Down
31 changes: 2 additions & 29 deletions sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,40 +88,13 @@ def test_create_ptransform(self):
spec = '''
transforms:
- type: PyMap
input: something
config:
fn: "lambda x: x*x"
'''
scope, spec = self.get_scope_by_spec(p, spec)

result = scope.create_ptransform(spec['transforms'][0], [])
self.assertIsInstance(result, beam.transforms.ParDo)
self.assertEqual(result.label, 'Map(lambda x: x*x)')

result_annotations = {**result.annotations()}
target_annotations = {
'yaml_type': 'PyMap',
'yaml_args': '{"fn": "lambda x: x*x"}',
'yaml_provider': '{"type": "InlineProvider"}'
}

# Check if target_annotations is a subset of result_annotations
self.assertDictEqual(
result_annotations, {
**result_annotations, **target_annotations
})

def test_create_ptransform_with_inputs(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
spec = '''
transforms:
- type: PyMap
config:
fn: "lambda x: x*x"
'''
scope, spec = self.get_scope_by_spec(p, spec)

result = scope.create_ptransform(spec['transforms'][0], [])
result = scope.create_ptransform(spec['transforms'][0], ['something'])
self.assertIsInstance(result, beam.transforms.ParDo)
self.assertEqual(result.label, 'Map(lambda x: x*x)')

Expand Down
Loading