diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
index a1cea2637cb5..5c76f9c228c8 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
@@ -876,11 +876,12 @@ def show(
 
 @progress_indicated
 def collect(
-    pcoll,
+    *pcolls,
     n='inf',
     duration='inf',
     include_window_info=False,
-    force_compute=False):
+    force_compute=False,
+    force_tuple=False):
   """Materializes the elements from a PCollection into a Dataframe.
 
   This reads each element from file and reads only the amount that it needs
@@ -889,6 +890,7 @@ def collect(
   it is assumed to be infinite.
 
   Args:
+    pcolls: PCollections to compute.
     n: (optional) max number of elements to visualize. Default 'inf'.
     duration: (optional) max duration of elements to read in integer seconds or
         a string duration. Default 'inf'.
@@ -896,6 +898,12 @@ def collect(
         to each row. Default False.
     force_compute: (optional) if True, forces recomputation rather than using
         cached PCollections
+    force_tuple: (optional) if True, return a 1-tuple of results rather than
+        the bare results if only one PCollection is computed
+
+  Returns:
+    A DataFrame if exactly one PCollection is given and force_tuple is False,
+    otherwise a tuple of DataFrames, one per argument in the order given.
 
   For example::
 
@@ -906,17 +914,28 @@ def collect(
     # Run the pipeline and bring the PCollection into memory as a Dataframe.
     in_memory_square = head(square, n=5)
   """
-  # Remember the element type so we can make an informed decision on how to
-  # collect the result in elements_to_df.
-  if isinstance(pcoll, DeferredBase):
-    # Get the proxy so we can get the output shape of the DataFrame.
-    pcoll, element_type = deferred_df_to_pcollection(pcoll)
-    watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
-  else:
-    element_type = pcoll.element_type
+  if len(pcolls) == 0:
+    return ()
+
+  def as_pcollection(pcoll_or_df):
+    if isinstance(pcoll_or_df, DeferredBase):
+      # Get the proxy so we can get the output shape of the DataFrame.
+      pcoll, element_type = deferred_df_to_pcollection(pcoll_or_df)
+      watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
+      return pcoll, element_type
+    elif isinstance(pcoll_or_df, beam.pvalue.PCollection):
+      return pcoll_or_df, pcoll_or_df.element_type
+    else:
+      raise TypeError(
+          f'{pcoll_or_df} is not an apache_beam.pvalue.PCollection.')
 
-  assert isinstance(pcoll, beam.pvalue.PCollection), (
-      '{} is not an apache_beam.pvalue.PCollection.'.format(pcoll))
+  pcolls_with_element_types = [as_pcollection(p) for p in pcolls]
+  pcolls_to_element_types = dict(pcolls_with_element_types)
+  pcolls = [pcoll for pcoll, _ in pcolls_with_element_types]
+  pipelines = set(pcoll.pipeline for pcoll in pcolls)
+  if len(pipelines) != 1:
+    raise ValueError('All PCollections must belong to the same pipeline.')
+  pipeline, = pipelines
 
   if isinstance(n, str):
     assert n == 'inf', (
@@ -935,45 +954,53 @@ def collect(
   if duration == 'inf':
     duration = float('inf')
 
-  user_pipeline = ie.current_env().user_pipeline(pcoll.pipeline)
+  user_pipeline = ie.current_env().user_pipeline(pipeline)
   # Possibly collecting a PCollection defined in a local scope that is not
   # explicitly watched. Ad hoc watch it though it's a little late.
   if not user_pipeline:
-    watch({'anonymous_pipeline_{}'.format(id(pcoll.pipeline)): pcoll.pipeline})
-    user_pipeline = pcoll.pipeline
+    watch({'anonymous_pipeline_{}'.format(id(pipeline)): pipeline})
+    user_pipeline = pipeline
   recording_manager = ie.current_env().get_recording_manager(
       user_pipeline, create_if_absent=True)
 
   # If already computed, directly read the stream and return.
-  if pcoll in ie.current_env().computed_pcollections and not force_compute:
-    pcoll_name = find_pcoll_name(pcoll)
-    elements = list(
-        recording_manager.read(pcoll_name, pcoll, n, duration).read())
-    return elements_to_df(
-        elements,
-        include_window_info=include_window_info,
-        element_type=element_type)
-
-  recording = recording_manager.record([pcoll],
-                                       max_n=n,
-                                       max_duration=duration,
-                                       force_compute=force_compute)
-
-  try:
-    elements = list(recording.stream(pcoll).read())
-  except KeyboardInterrupt:
-    recording.cancel()
-    return pd.DataFrame()
+  computed = {}
+  for pcoll in pcolls_to_element_types.keys():
+    if pcoll in ie.current_env().computed_pcollections and not force_compute:
+      pcoll_name = find_pcoll_name(pcoll)
+      computed[pcoll] = list(
+          recording_manager.read(pcoll_name, pcoll, n, duration).read())
+
+  uncomputed = set(pcolls) - set(computed.keys())
+  if uncomputed:
+    recording = recording_manager.record(
+        uncomputed, max_n=n, max_duration=duration, force_compute=force_compute)
+
+    try:
+      for pcoll in uncomputed:
+        computed[pcoll] = list(recording.stream(pcoll).read())
+    except KeyboardInterrupt:
+      # Stop recording; PCollections not read before the interrupt are
+      # returned as empty DataFrames below.
+      recording.cancel()
 
   if n == float('inf'):
     n = None
 
   # Collecting DataFrames may have a length > n, so slice again to be sure. Note
   # that array[:None] returns everything.
-  return elements_to_df(
-      elements,
-      include_window_info=include_window_info,
-      element_type=element_type)[:n]
+  empty = pd.DataFrame()
+  result_tuple = tuple(
+      elements_to_df(
+          computed[pcoll],
+          include_window_info=include_window_info,
+          element_type=pcolls_to_element_types[pcoll])[:n] if pcoll in
+      computed else empty for pcoll in pcolls)
+
+  if len(result_tuple) == 1 and not force_tuple:
+    return result_tuple[0]
+  else:
+    return result_tuple
 
 
 @progress_indicated
diff --git a/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py
index 8f395aeeda5b..47adf7b36b33 100644
--- a/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py
+++ b/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py
@@ -100,6 +100,50 @@ def test_basic(self):
     self.assertEqual(set(collected2[0]), set(['A', 'B', 'C']))
     self.assertEqual(count_side_effects('a'), 2)
 
+  @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
+  def test_multiple_collect(self):
+    clear_side_effect()
+    p = beam.Pipeline(direct_runner.DirectRunner())
+
+    # Initial collection runs the pipeline.
+    pcollA = p | 'A' >> beam.Create(['a']) | 'As' >> beam.Map(cause_side_effect)
+    pcollB = p | 'B' >> beam.Create(['b']) | 'Bs' >> beam.Map(cause_side_effect)
+    collectedA, collectedB = ib.collect(pcollA, pcollB)
+    self.assertEqual(set(collectedA[0]), set(['a']))
+    self.assertEqual(set(collectedB[0]), set(['b']))
+    self.assertEqual(count_side_effects('a'), 1)
+    self.assertEqual(count_side_effects('b'), 1)
+
+    # Collecting the PCollection again uses the cache.
+    collectedA, collectedB = ib.collect(pcollA, pcollB)
+    self.assertEqual(set(collectedA[0]), set(['a']))
+    self.assertEqual(set(collectedB[0]), set(['b']))
+    self.assertEqual(count_side_effects('a'), 1)
+    self.assertEqual(count_side_effects('b'), 1)
+
+    # Using the PCollection uses the cache.
+    pcollAA = pcollA | beam.Map(
+        lambda x: 2 * x) | 'AAs' >> beam.Map(cause_side_effect)
+    collectedA, collectedB, collectedAA = ib.collect(pcollA, pcollB, pcollAA)
+    self.assertEqual(set(collectedA[0]), set(['a']))
+    self.assertEqual(set(collectedB[0]), set(['b']))
+    self.assertEqual(set(collectedAA[0]), set(['aa']))
+    self.assertEqual(count_side_effects('a'), 1)
+    self.assertEqual(count_side_effects('b'), 1)
+    self.assertEqual(count_side_effects('aa'), 1)
+
+    # Duplicates are only computed once.
+    pcollBB = pcollB | beam.Map(
+        lambda x: 2 * x) | 'BBs' >> beam.Map(cause_side_effect)
+    collectedAA, collectedAAagain, collectedBB, collectedBBagain = ib.collect(
+        pcollAA, pcollAA, pcollBB, pcollBB)
+    self.assertEqual(set(collectedAA[0]), set(['aa']))
+    self.assertEqual(set(collectedAAagain[0]), set(['aa']))
+    self.assertEqual(set(collectedBB[0]), set(['bb']))
+    self.assertEqual(set(collectedBBagain[0]), set(['bb']))
+    self.assertEqual(count_side_effects('aa'), 1)
+    self.assertEqual(count_side_effects('bb'), 1)
+
   @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
   def test_wordcount(self):
     class WordExtractingDoFn(beam.DoFn):