Skip to content

Commit

Permalink
Validate circular reference for yaml (apache#33208)
Browse files Browse the repository at this point in the history
  • Loading branch information
mravi authored Dec 18, 2024
1 parent 286e29c commit e68a79c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,21 @@ def preprocess_languages(spec):
else:
return spec

def validate_transform_references(spec):
name = spec.get('name', '')
transform_type = spec.get('type')
inputs = spec.get('input').get('input', [])

if not is_empty(inputs):
input_values = [inputs] if isinstance(inputs, str) else inputs
for input_value in input_values:
if input_value in (name, transform_type):
raise ValueError(
f"Circular reference detected: Transform {name} "
f"references itself as input in {identify_object(spec)}")

return spec

for phase in [
ensure_transforms_have_types,
normalize_mapping,
Expand All @@ -966,6 +981,7 @@ def preprocess_languages(spec):
preprocess_chain,
tag_explicit_inputs,
normalize_inputs_outputs,
validate_transform_references,
preprocess_flattened_inputs,
ensure_errors_consumed,
preprocess_windowing,
Expand Down
47 changes: 46 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,51 @@ def test_csv_to_json(self):
lines=True).sort_values('rank').reindex()
pd.testing.assert_frame_equal(data, result)

def test_circular_reference_validation(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
# pylint: disable=expression-not-assigned
with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'):
p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: Create
config:
elements: [0, 1, 3, 4]
input: Create
- type: PyMap
name: PyMap
config:
fn: "lambda row: row.element * row.element"
input: Create
output: PyMap
''',
providers=TEST_PROVIDERS)

def test_circular_reference_multi_inputs_validation(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
# pylint: disable=expression-not-assigned
with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'):
p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: Create
config:
elements: [0, 1, 3, 4]
- type: PyMap
name: PyMap
config:
fn: "lambda row: row.element * row.element"
input: [Create, PyMap]
output: PyMap
''',
providers=TEST_PROVIDERS)

def test_name_is_not_ambiguous(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
Expand All @@ -285,7 +330,7 @@ def test_name_is_ambiguous(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
# pylint: disable=expression-not-assigned
with self.assertRaisesRegex(ValueError, r'Ambiguous.*'):
with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'):
p | YamlTransform(
'''
type: composite
Expand Down

0 comments on commit e68a79c

Please sign in to comment.