diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 7cb96a7efb32..327023742bc6 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -956,6 +956,21 @@ def preprocess_languages(spec): else: return spec + def validate_transform_references(spec): + name = spec.get('name', '') + transform_type = spec.get('type') + inputs = spec.get('input').get('input', []) + + if not is_empty(inputs): + input_values = [inputs] if isinstance(inputs, str) else inputs + for input_value in input_values: + if input_value in (name, transform_type): + raise ValueError( + f"Circular reference detected: Transform {name} " + f"references itself as input in {identify_object(spec)}") + + return spec + for phase in [ ensure_transforms_have_types, normalize_mapping, @@ -966,6 +981,7 @@ def preprocess_languages(spec): preprocess_chain, tag_explicit_inputs, normalize_inputs_outputs, + validate_transform_references, preprocess_flattened_inputs, ensure_errors_consumed, preprocess_windowing, diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py index 7fcea7e2b662..b9caca4ca9f4 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py @@ -259,6 +259,51 @@ def test_csv_to_json(self): lines=True).sort_values('rank').reindex() pd.testing.assert_frame_equal(data, result) + def test_circular_reference_validation(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + # pylint: disable=expression-not-assigned + with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'): + p | YamlTransform( + ''' + type: composite + transforms: + - type: Create + name: Create + config: + elements: [0, 1, 3, 4] + input: Create + - type: PyMap + name: PyMap + config: + fn: "lambda row: row.element * row.element" + input: Create + output: PyMap + ''', + providers=TEST_PROVIDERS) + + def test_circular_reference_multi_inputs_validation(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + # pylint: disable=expression-not-assigned + with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'): + p | YamlTransform( + ''' + type: composite + transforms: + - type: Create + name: Create + config: + elements: [0, 1, 3, 4] + - type: PyMap + name: PyMap + config: + fn: "lambda row: row.element * row.element" + input: [Create, PyMap] + output: PyMap + ''', + providers=TEST_PROVIDERS) + def test_name_is_not_ambiguous(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: @@ -285,7 +330,7 @@ def test_name_is_ambiguous(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: # pylint: disable=expression-not-assigned - with self.assertRaisesRegex(ValueError, r'Ambiguous.*'): + with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'): p | YamlTransform( ''' type: composite