From ba6343bfe3acdf295de57caa8572f32b8d731cd3 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 30 Sep 2024 12:43:17 -0400 Subject: [PATCH] Force BQIO to output elements in the correct row (#32584) * Fix bqio * import fix * syntax * feedback --- .../trigger_files/beam_PostCommit_Python.json | 2 +- ...it_Python_ValidatesContainer_Dataflow.json | 3 +- sdks/python/apache_beam/io/gcp/bigquery.py | 59 +++++++++++-------- .../io/gcp/bigquery_write_it_test.py | 1 + 4 files changed, 38 insertions(+), 27 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index 30ee463ad4e9..1eb60f6e4959 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", - "modification": 2 + "modification": 3 } diff --git a/.github/trigger_files/beam_PostCommit_Python_ValidatesContainer_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_ValidatesContainer_Dataflow.json index d6c608f6daba..4897480d69ad 100644 --- a/.github/trigger_files/beam_PostCommit_Python_ValidatesContainer_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesContainer_Dataflow.json @@ -1,3 +1,4 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 } \ No newline at end of file diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index b897df2d32ab..2cb64742f26c 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -418,7 +418,6 @@ def chain_after(result): from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX from apache_beam.transforms.sideinputs import get_sideinput_index from apache_beam.transforms.util import ReshufflePerKey -from apache_beam.transforms.window import GlobalWindows from apache_beam.typehints.row_type import RowTypeConstraint from apache_beam.typehints.schemas import schema_from_element_type from apache_beam.utils import retry @@ -1581,7 +1580,8 @@ def _create_table_if_needed(self, table_reference, schema=None): additional_create_parameters=self.additional_bq_parameters) _KNOWN_TABLES.add(str_table_reference) - def process(self, element, *schema_side_inputs): + def process( + self, element, window_value=DoFn.WindowedValueParam, *schema_side_inputs): destination = bigquery_tools.get_hashable_destination(element[0]) if callable(self.schema): @@ -1608,12 +1608,11 @@ def process(self, element, *schema_side_inputs): return [ pvalue.TaggedOutput( BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS, - GlobalWindows.windowed_value( + window_value.with_value( (destination, row_and_insert_id[0], error))), pvalue.TaggedOutput( BigQueryWriteFn.FAILED_ROWS, - GlobalWindows.windowed_value( - (destination, row_and_insert_id[0]))) + window_value.with_value((destination, row_and_insert_id[0]))) ] # Flush current batch first if adding this row will exceed our limits @@ -1624,11 +1623,11 @@ def process(self, element, *schema_side_inputs): flushed_batch = self._flush_batch(destination) # After flushing our existing batch, we now buffer the current row # for the next flush - self._rows_buffer[destination].append(row_and_insert_id) + self._rows_buffer[destination].append((row_and_insert_id, window_value)) self._destination_buffer_byte_size[destination] = row_byte_size return flushed_batch - self._rows_buffer[destination].append(row_and_insert_id) + self._rows_buffer[destination].append((row_and_insert_id, window_value)) self._destination_buffer_byte_size[destination] += row_byte_size self._total_buffered_rows += 1 if self._total_buffered_rows >= self._max_buffered_rows: @@ -1636,7 +1635,8 @@ def process(self, element, *schema_side_inputs): else: # The input is already batched per destination, flush the rows now. batched_rows = element[1] - self._rows_buffer[destination].extend(batched_rows) + for r in batched_rows: + self._rows_buffer[destination].append((r, window_value)) return self._flush_batch(destination) def finish_bundle(self): @@ -1659,7 +1659,7 @@ def _flush_all_batches(self): def _flush_batch(self, destination): # Flush the current batch of rows to BigQuery. - rows_and_insert_ids = self._rows_buffer[destination] + rows_and_insert_ids_with_windows = self._rows_buffer[destination] table_reference = bigquery_tools.parse_table_reference(destination) if table_reference.projectId is None: table_reference.projectId = vp.RuntimeValueProvider.get_value( @@ -1668,9 +1668,10 @@ def _flush_batch(self, destination): _LOGGER.debug( 'Flushing data to %s. Total %s rows.', destination, - len(rows_and_insert_ids)) - self.batch_size_metric.update(len(rows_and_insert_ids)) + len(rows_and_insert_ids_with_windows)) + self.batch_size_metric.update(len(rows_and_insert_ids_with_windows)) + rows_and_insert_ids, window_values = zip(*rows_and_insert_ids_with_windows) rows = [r[0] for r in rows_and_insert_ids] if self.ignore_insert_ids: insert_ids = [None for r in rows_and_insert_ids] @@ -1689,8 +1690,10 @@ def _flush_batch(self, destination): ignore_unknown_values=self.ignore_unknown_columns) self.batch_latency_metric.update((time.time() - start) * 1000) - failed_rows = [(rows[entry['index']], entry["errors"]) + failed_rows = [( + rows[entry['index']], entry["errors"], window_values[entry['index']]) for entry in errors] + failed_insert_ids = [insert_ids[entry['index']] for entry in errors] retry_backoff = next(self._backoff_calculator, None) # If retry_backoff is None, then we will not retry and must log. @@ -1721,7 +1724,11 @@ def _flush_batch(self, destination): _LOGGER.info( 'Sleeping %s seconds before retrying insertion.', retry_backoff) time.sleep(retry_backoff) + # We can now safely discard all information about successful rows and + # just focus on the failed ones rows = [fr[0] for fr in failed_rows] + window_values = [fr[2] for fr in failed_rows] + insert_ids = failed_insert_ids self._throttled_secs.inc(retry_backoff) self._total_buffered_rows -= len(self._rows_buffer[destination]) @@ -1729,19 +1736,21 @@ def _flush_batch(self, destination): if destination in self._destination_buffer_byte_size: del self._destination_buffer_byte_size[destination] - return itertools.chain([ - pvalue.TaggedOutput( - BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS, - GlobalWindows.windowed_value((destination, row, err))) for row, - err in failed_rows - ], - [ - pvalue.TaggedOutput( - BigQueryWriteFn.FAILED_ROWS, - GlobalWindows.windowed_value( - (destination, row))) for row, - unused_err in failed_rows - ]) + return itertools.chain( + [ + pvalue.TaggedOutput( + BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS, + w.with_value((destination, row, err))) for row, + err, + w in failed_rows + ], + [ + pvalue.TaggedOutput( + BigQueryWriteFn.FAILED_ROWS, w.with_value((destination, row))) + for row, + unused_err, + w in failed_rows + ]) # The number of shards per destination when writing via streaming inserts. diff --git a/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py index b0140793cf79..cd3edf19de5f 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py @@ -491,6 +491,7 @@ def test_big_query_write_insert_non_transient_api_call_error(self): # pylint: disable=expression-not-assigned errors = ( p | 'create' >> beam.Create(input_data) + | beam.WindowInto(beam.transforms.window.FixedWindows(10)) | 'write' >> beam.io.WriteToBigQuery( table_id, schema=table_schema,