Skip to content

Commit

Permalink
Propagate correct boundedness when using multiple outputs (#29506)
Browse files Browse the repository at this point in the history
* Add failing test

* boundedness fix
  • Loading branch information
damccorm authored Nov 21, 2023
1 parent a349e76 commit 0972bc0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
7 changes: 6 additions & 1 deletion sdks/python/apache_beam/pvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,12 @@ def __getitem__(self, tag):
assert self.producer is not None
if tag is not None:
self._transform.output_tags.add(tag)
pcoll = PCollection(self._pipeline, tag=tag, element_type=typehints.Any)
is_bounded = all(i.is_bounded for i in self.producer.main_inputs.values())
pcoll = PCollection(
self._pipeline,
tag=tag,
element_type=typehints.Any,
is_bounded=is_bounded)
# Transfer the producer from the DoOutputsTuple to the resulting
# PCollection.
pcoll.producer = self.producer.parts[0]
Expand Down
28 changes: 28 additions & 0 deletions sdks/python/apache_beam/transforms/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,34 @@ def test_dofn_with_yield_and_return(self):
assert warning_text in self._caplog.text


class PartitionTest(unittest.TestCase):
def test_partition_boundedness(self):
def partition_fn(val, num_partitions):
return val % num_partitions

class UnboundedDoFn(beam.DoFn):
@beam.DoFn.unbounded_per_element()
def process(self, element):
yield element

with beam.testing.test_pipeline.TestPipeline() as p:
source = p | beam.Create([1, 2, 3, 4, 5])
p1, p2, p3 = source | "bounded" >> beam.Partition(partition_fn, 3)

self.assertEqual(source.is_bounded, True)
self.assertEqual(p1.is_bounded, True)
self.assertEqual(p2.is_bounded, True)
self.assertEqual(p3.is_bounded, True)

unbounded = source | beam.ParDo(UnboundedDoFn())
p4, p5, p6 = unbounded | "unbounded" >> beam.Partition(partition_fn, 3)

self.assertEqual(unbounded.is_bounded, False)
self.assertEqual(p4.is_bounded, False)
self.assertEqual(p5.is_bounded, False)
self.assertEqual(p6.is_bounded, False)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()

0 comments on commit 0972bc0

Please sign in to comment.