diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 90882651d0b2..b783d61f95c9 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -304,7 +304,12 @@ def __getitem__(self, tag): assert self.producer is not None if tag is not None: self._transform.output_tags.add(tag) - pcoll = PCollection(self._pipeline, tag=tag, element_type=typehints.Any) + is_bounded = all(i.is_bounded for i in self.producer.main_inputs.values()) + pcoll = PCollection( + self._pipeline, + tag=tag, + element_type=typehints.Any, + is_bounded=is_bounded) # Transfer the producer from the DoOutputsTuple to the resulting # PCollection. pcoll.producer = self.producer.parts[0] diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index 0fba28266138..a60974ceb706 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -108,6 +108,34 @@ def test_dofn_with_yield_and_return(self): assert warning_text in self._caplog.text +class PartitionTest(unittest.TestCase): + def test_partition_boundedness(self): + def partition_fn(val, num_partitions): + return val % num_partitions + + class UnboundedDoFn(beam.DoFn): + @beam.DoFn.unbounded_per_element() + def process(self, element): + yield element + + with beam.testing.test_pipeline.TestPipeline() as p: + source = p | beam.Create([1, 2, 3, 4, 5]) + p1, p2, p3 = source | "bounded" >> beam.Partition(partition_fn, 3) + + self.assertEqual(source.is_bounded, True) + self.assertEqual(p1.is_bounded, True) + self.assertEqual(p2.is_bounded, True) + self.assertEqual(p3.is_bounded, True) + + unbounded = source | beam.ParDo(UnboundedDoFn()) + p4, p5, p6 = unbounded | "unbounded" >> beam.Partition(partition_fn, 3) + + self.assertEqual(unbounded.is_bounded, False) + self.assertEqual(p4.is_bounded, False) + self.assertEqual(p5.is_bounded, False) + self.assertEqual(p6.is_bounded, False) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()