diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index d7415e8d8135..a3762adac0cb 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1462,8 +1462,14 @@ def partition_for(self, element, num_partitions, *args, **kwargs): def _get_function_body_without_inners(func): source_lines = inspect.getsourcelines(func)[0] source_lines = dropwhile(lambda x: x.startswith("@"), source_lines) - def_line = next(source_lines).strip() - if def_line.startswith("def ") and def_line.endswith(":"): + first_def_line = next(source_lines).strip() + if first_def_line.startswith("def "): + last_def_line_without_comment = first_def_line.split("#")[0] \ + .split("\"\"\"")[0] + while not last_def_line_without_comment.strip().endswith(":"): + last_def_line_without_comment = next(source_lines).split("#")[0] \ + .split("\"\"\"")[0] + first_line = next(source_lines) indentation = len(first_line) - len(first_line.lstrip()) final_lines = [first_line[indentation:]] @@ -1487,7 +1493,7 @@ def _get_function_body_without_inners(func): return "".join(final_lines) else: - return def_line.rsplit(":")[-1].strip() + return first_def_line.rsplit(":")[-1].strip() def _check_fn_use_yield_and_return(fn): @@ -1497,15 +1503,26 @@ def _check_fn_use_yield_and_return(fn): source_code = _get_function_body_without_inners(fn) has_yield = False has_return = False + return_none_warning = ( + "No iterator is returned by the process method in %s.", + fn.__self__.__class__) for line in source_code.split("\n"): - if line.lstrip().startswith("yield ") or line.lstrip().startswith( + lstripped_line = line.lstrip() + if lstripped_line.startswith("yield ") or lstripped_line.startswith( "yield("): has_yield = True - if line.lstrip().startswith("return ") or line.lstrip().startswith( + if lstripped_line.startswith("return ") or lstripped_line.startswith( "return("): has_return = True + if lstripped_line.startswith( + "return None") or lstripped_line.rstrip() == "return": + _LOGGER.warning(return_none_warning) if has_yield and has_return: return True + + if not has_yield and not has_return: + _LOGGER.warning(return_none_warning) + return False except Exception as e: _LOGGER.debug(str(e)) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index b492ab0938cc..54afb365d2d8 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -30,6 +30,8 @@ from apache_beam.testing.util import equal_to from apache_beam.transforms.window import FixedWindows +RETURN_NONE_PARTIAL_WARNING = "No iterator is returned" + class TestDoFn1(beam.DoFn): def process(self, element): @@ -96,6 +98,24 @@ def process(self, element): yield element +class TestDoFn10(beam.DoFn): + """test process returning None explicitly""" + def process(self, element): + return None + + +class TestDoFn11(beam.DoFn): + """test process returning None (no return and no yield)""" + def process(self, element): + pass + + +class TestDoFn12(beam.DoFn): + """test process returning None (return statement without a value)""" + def process(self, element): + return + + class CreateTest(unittest.TestCase): @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): @@ -119,6 +139,24 @@ def test_dofn_with_yield_and_return(self): beam.ParDo(TestDoFn3()) assert warning_text in self._caplog.text + def test_dofn_with_explicit_return_none(self): + with self._caplog.at_level(logging.WARNING): + beam.ParDo(TestDoFn10()) + assert RETURN_NONE_PARTIAL_WARNING in self._caplog.text + assert str(TestDoFn10) in self._caplog.text + + def test_dofn_with_implicit_return_none_missing_return_and_yield(self): + with self._caplog.at_level(logging.WARNING): + beam.ParDo(TestDoFn11()) + assert RETURN_NONE_PARTIAL_WARNING in self._caplog.text + assert str(TestDoFn11) in self._caplog.text + + def test_dofn_with_implicit_return_none_return_without_value(self): + with self._caplog.at_level(logging.WARNING): + beam.ParDo(TestDoFn12()) + assert RETURN_NONE_PARTIAL_WARNING in self._caplog.text + assert str(TestDoFn12) in self._caplog.text + class PartitionTest(unittest.TestCase): def test_partition_boundedness(self):