Skip to content

Commit

Permalink
Force BQIO to output elements in the correct row (apache#32584)
Browse files Browse the repository at this point in the history
* Fix bqio

* import fix

* syntax

* feedback
  • Loading branch information
damccorm authored Sep 30, 2024
1 parent 48f836a commit e640b25
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .github/trigger_files/beam_PostCommit_Python.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run.",
"modification": 2
"modification": 3
}

Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run"
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 1
}
59 changes: 34 additions & 25 deletions sdks/python/apache_beam/io/gcp/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,6 @@ def chain_after(result):
from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX
from apache_beam.transforms.sideinputs import get_sideinput_index
from apache_beam.transforms.util import ReshufflePerKey
from apache_beam.transforms.window import GlobalWindows
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.typehints.schemas import schema_from_element_type
from apache_beam.utils import retry
Expand Down Expand Up @@ -1581,7 +1580,8 @@ def _create_table_if_needed(self, table_reference, schema=None):
additional_create_parameters=self.additional_bq_parameters)
_KNOWN_TABLES.add(str_table_reference)

def process(self, element, *schema_side_inputs):
def process(
self, element, window_value=DoFn.WindowedValueParam, *schema_side_inputs):
destination = bigquery_tools.get_hashable_destination(element[0])

if callable(self.schema):
Expand All @@ -1608,12 +1608,11 @@ def process(self, element, *schema_side_inputs):
return [
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS,
GlobalWindows.windowed_value(
window_value.with_value(
(destination, row_and_insert_id[0], error))),
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS,
GlobalWindows.windowed_value(
(destination, row_and_insert_id[0])))
window_value.with_value((destination, row_and_insert_id[0])))
]

# Flush current batch first if adding this row will exceed our limits
Expand All @@ -1624,19 +1623,20 @@ def process(self, element, *schema_side_inputs):
flushed_batch = self._flush_batch(destination)
# After flushing our existing batch, we now buffer the current row
# for the next flush
self._rows_buffer[destination].append(row_and_insert_id)
self._rows_buffer[destination].append((row_and_insert_id, window_value))
self._destination_buffer_byte_size[destination] = row_byte_size
return flushed_batch

self._rows_buffer[destination].append(row_and_insert_id)
self._rows_buffer[destination].append((row_and_insert_id, window_value))
self._destination_buffer_byte_size[destination] += row_byte_size
self._total_buffered_rows += 1
if self._total_buffered_rows >= self._max_buffered_rows:
return self._flush_all_batches()
else:
# The input is already batched per destination, flush the rows now.
batched_rows = element[1]
self._rows_buffer[destination].extend(batched_rows)
for r in batched_rows:
self._rows_buffer[destination].append((r, window_value))
return self._flush_batch(destination)

def finish_bundle(self):
Expand All @@ -1659,7 +1659,7 @@ def _flush_all_batches(self):
def _flush_batch(self, destination):

# Flush the current batch of rows to BigQuery.
rows_and_insert_ids = self._rows_buffer[destination]
rows_and_insert_ids_with_windows = self._rows_buffer[destination]
table_reference = bigquery_tools.parse_table_reference(destination)
if table_reference.projectId is None:
table_reference.projectId = vp.RuntimeValueProvider.get_value(
Expand All @@ -1668,9 +1668,10 @@ def _flush_batch(self, destination):
_LOGGER.debug(
'Flushing data to %s. Total %s rows.',
destination,
len(rows_and_insert_ids))
self.batch_size_metric.update(len(rows_and_insert_ids))
len(rows_and_insert_ids_with_windows))
self.batch_size_metric.update(len(rows_and_insert_ids_with_windows))

rows_and_insert_ids, window_values = zip(*rows_and_insert_ids_with_windows)
rows = [r[0] for r in rows_and_insert_ids]
if self.ignore_insert_ids:
insert_ids = [None for r in rows_and_insert_ids]
Expand All @@ -1689,8 +1690,10 @@ def _flush_batch(self, destination):
ignore_unknown_values=self.ignore_unknown_columns)
self.batch_latency_metric.update((time.time() - start) * 1000)

failed_rows = [(rows[entry['index']], entry["errors"])
failed_rows = [(
rows[entry['index']], entry["errors"], window_values[entry['index']])
for entry in errors]
failed_insert_ids = [insert_ids[entry['index']] for entry in errors]
retry_backoff = next(self._backoff_calculator, None)

# If retry_backoff is None, then we will not retry and must log.
Expand Down Expand Up @@ -1721,27 +1724,33 @@ def _flush_batch(self, destination):
_LOGGER.info(
'Sleeping %s seconds before retrying insertion.', retry_backoff)
time.sleep(retry_backoff)
# We can now safely discard all information about successful rows and
# just focus on the failed ones
rows = [fr[0] for fr in failed_rows]
window_values = [fr[2] for fr in failed_rows]
insert_ids = failed_insert_ids
self._throttled_secs.inc(retry_backoff)

self._total_buffered_rows -= len(self._rows_buffer[destination])
del self._rows_buffer[destination]
if destination in self._destination_buffer_byte_size:
del self._destination_buffer_byte_size[destination]

return itertools.chain([
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS,
GlobalWindows.windowed_value((destination, row, err))) for row,
err in failed_rows
],
[
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS,
GlobalWindows.windowed_value(
(destination, row))) for row,
unused_err in failed_rows
])
return itertools.chain(
[
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS,
w.with_value((destination, row, err))) for row,
err,
w in failed_rows
],
[
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS, w.with_value((destination, row)))
for row,
unused_err,
w in failed_rows
])


# The number of shards per destination when writing via streaming inserts.
Expand Down
1 change: 1 addition & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ def test_big_query_write_insert_non_transient_api_call_error(self):
# pylint: disable=expression-not-assigned
errors = (
p | 'create' >> beam.Create(input_data)
| beam.WindowInto(beam.transforms.window.FixedWindows(10))
| 'write' >> beam.io.WriteToBigQuery(
table_id,
schema=table_schema,
Expand Down

0 comments on commit e640b25

Please sign in to comment.