diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py
index 6edb72858..20b5618ea 100644
--- a/google/cloud/bigtable/data/_async/_read_rows.py
+++ b/google/cloud/bigtable/data/_async/_read_rows.py
@@ -31,7 +31,7 @@
 from google.cloud.bigtable.data._helpers import _make_metadata
 
 from google.api_core import retry_async as retries
-from google.api_core.retry_streaming_async import AsyncRetryableGenerator
+from google.api_core.retry_streaming_async import retry_target_stream
 from google.api_core.retry import exponential_sleep_generator
 from google.api_core import exceptions as core_exceptions
 
@@ -100,35 +100,17 @@ def __init__(
         self._last_yielded_row_key: bytes | None = None
         self._remaining_count: int | None = self.request.rows_limit or None
 
-    async def start_operation(self) -> AsyncGenerator[Row, None]:
+    def start_operation(self) -> AsyncGenerator[Row, None]:
         """
         Start the read_rows operation, retrying on retryable errors.
         """
-        transient_errors = []
-
-        def on_error_fn(exc):
-            if self._predicate(exc):
-                transient_errors.append(exc)
-
-        retry_gen = AsyncRetryableGenerator(
+        return retry_target_stream(
             self._read_rows_attempt,
             self._predicate,
             exponential_sleep_generator(0.01, 60, multiplier=2),
             self.operation_timeout,
-            on_error_fn,
+            exception_factory=self._build_exception,
         )
-        try:
-            async for row in retry_gen:
-                yield row
-                if self._remaining_count is not None:
-                    self._remaining_count -= 1
-                    if self._remaining_count < 0:
-                        raise RuntimeError("emit count exceeds row limit")
-        except core_exceptions.RetryError:
-            self._raise_retry_error(transient_errors)
-        except GeneratorExit:
-            # propagate close to wrapped generator
-            await retry_gen.aclose()
 
     def _read_rows_attempt(self) -> AsyncGenerator[Row, None]:
         """
@@ -202,6 +184,10 @@ async def chunk_stream(
                 elif c.commit_row:
                     # update row state after each commit
                     self._last_yielded_row_key = current_key
+                    if self._remaining_count is not None:
+                        self._remaining_count -= 1
+                        if self._remaining_count < 0:
+                            raise InvalidChunk("emit count exceeds row limit")
                     current_key = None
 
     @staticmethod
@@ -354,19 +340,34 @@ def _revise_request_rowset(
             raise _RowSetComplete()
         return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges)
 
-    def _raise_retry_error(self, transient_errors: list[Exception]) -> None:
+    @staticmethod
+    def _build_exception(
+        exc_list: list[Exception], is_timeout: bool, timeout_val: float
+    ) -> tuple[Exception, Exception | None]:
         """
-        If the retryable deadline is hit, wrap the raised exception
-        in a RetryExceptionGroup
+        Build retry error based on exceptions encountered during operation
+
+        Args:
+          - exc_list: list of exceptions encountered during operation
+          - is_timeout: whether the operation failed due to timeout
+          - timeout_val: the operation timeout value in seconds, for constructing
+              the error message
+        Returns:
+          - tuple of the exception to raise, and a cause exception if applicable
         """
-        timeout_value = self.operation_timeout
-        timeout_str = f" of {timeout_value:.1f}s" if timeout_value is not None else ""
-        error_str = f"operation_timeout{timeout_str} exceeded"
-        new_exc = core_exceptions.DeadlineExceeded(
-            error_str,
+        if is_timeout:
+            # if failed due to timeout, raise deadline exceeded as primary exception
+            source_exc: Exception = core_exceptions.DeadlineExceeded(
+                f"operation_timeout of {timeout_val} exceeded"
+            )
+        elif exc_list:
+            # otherwise, raise non-retryable error as primary exception
+            source_exc = exc_list.pop()
+        else:
+            source_exc = RuntimeError("failed with unspecified exception")
+        # use the retry exception group as the cause of the exception
+        cause_exc: Exception | None = (
+            RetryExceptionGroup(exc_list) if exc_list else None
         )
-        source_exc = None
-        if transient_errors:
-            source_exc = RetryExceptionGroup(transient_errors)
-        new_exc.__cause__ = source_exc
-        raise new_exc from source_exc
+        source_exc.__cause__ = cause_exc
+        return source_exc, cause_exc
diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py
index b13b670d4..1f8a63d21 100644
--- a/google/cloud/bigtable/data/_helpers.py
+++ b/google/cloud/bigtable/data/_helpers.py
@@ -62,6 +62,9 @@ def _attempt_timeout_generator(
         yield max(0, min(per_request_timeout, deadline - time.monotonic()))
 
 
+# TODO:replace this function with an exception_factory passed into the retry when
+# feature is merged:
+# https://github.com/googleapis/python-bigtable/blob/ea5b4f923e42516729c57113ddbe28096841b952/google/cloud/bigtable/data/_async/_read_rows.py#L130
 def _convert_retry_deadline(
     func: Callable[..., Any],
     timeout_value: float | None = None,
diff --git a/python-api-core b/python-api-core
index a526d6593..a8cfa66b8 160000
--- a/python-api-core
+++ b/python-api-core
@@ -1 +1 @@
-Subproject commit a526d659320939cd7f47ee775b250e8a3e3ab16b
+Subproject commit a8cfa66b8d6001da56823c6488b5da4957e5702b
diff --git a/setup.py b/setup.py
index e05b37c79..e5efc9937 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@
 # 'Development Status :: 5 - Production/Stable'
 release_status = "Development Status :: 5 - Production/Stable"
 dependencies = [
-    "google-api-core[grpc] == 2.12.0.dev0",  # TODO: change to >= after streaming retries is merged
+    "google-api-core[grpc] == 2.12.0.dev1",  # TODO: change to >= after streaming retries is merged
     "google-cloud-core >= 1.4.1, <3.0.0dev",
     "grpc-google-iam-v1 >= 0.12.4, <1.0.0dev",
     "proto-plus >= 1.22.0, <2.0.0dev",
diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt
index 92b616563..9f23121d1 100644
--- a/testing/constraints-3.7.txt
+++ b/testing/constraints-3.7.txt
@@ -5,7 +5,7 @@
 #
 # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
 # Then this file should have foo==1.14.0
-google-api-core==2.12.0.dev0
+google-api-core==2.12.0.dev1
 google-cloud-core==2.3.2
 grpc-google-iam-v1==0.12.4
 proto-plus==1.22.0
diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py
index 200defbbf..4e7797c6d 100644
--- a/tests/unit/data/_async/test__read_rows.py
+++ b/tests/unit/data/_async/test__read_rows.py
@@ -226,24 +226,34 @@ async def test_revise_limit(self, start_limit, emit_num, expected_limit):
         should be raised (tested in test_revise_limit_over_limit)
         """
         from google.cloud.bigtable.data import ReadRowsQuery
+        from google.cloud.bigtable_v2.types import ReadRowsResponse
 
-        async def mock_stream():
-            for i in range(emit_num):
-                yield i
+        async def awaitable_stream():
+            async def mock_stream():
+                for i in range(emit_num):
+                    yield ReadRowsResponse(
+                        chunks=[
+                            ReadRowsResponse.CellChunk(
+                                row_key=str(i).encode(),
+                                family_name="b",
+                                qualifier=b"c",
+                                value=b"d",
+                                commit_row=True,
+                            )
+                        ]
+                    )
+
+            return mock_stream()
 
         query = ReadRowsQuery(limit=start_limit)
         table = mock.Mock()
         table.table_name = "table_name"
         table.app_profile_id = "app_profile_id"
-        with mock.patch.object(
-            _ReadRowsOperationAsync, "_read_rows_attempt"
-        ) as mock_attempt:
-            mock_attempt.return_value = mock_stream()
-            instance = self._make_one(query, table, 10, 10)
-            assert instance._remaining_count == start_limit
-            # read emit_num rows
-            async for val in instance.start_operation():
-                pass
+        instance = self._make_one(query, table, 10, 10)
+        assert instance._remaining_count == start_limit
+        # read emit_num rows
+        async for val in instance.chunk_stream(awaitable_stream()):
+            pass
         assert instance._remaining_count == expected_limit
 
     @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)])
@@ -254,26 +264,37 @@ async def test_revise_limit_over_limit(self, start_limit, emit_num):
         (unless start_num == 0, which represents unlimited)
         """
         from google.cloud.bigtable.data import ReadRowsQuery
+        from google.cloud.bigtable_v2.types import ReadRowsResponse
+        from google.cloud.bigtable.data.exceptions import InvalidChunk
 
-        async def mock_stream():
-            for i in range(emit_num):
-                yield i
+        async def awaitable_stream():
+            async def mock_stream():
+                for i in range(emit_num):
+                    yield ReadRowsResponse(
+                        chunks=[
+                            ReadRowsResponse.CellChunk(
+                                row_key=str(i).encode(),
+                                family_name="b",
+                                qualifier=b"c",
+                                value=b"d",
+                                commit_row=True,
+                            )
+                        ]
+                    )
+
+            return mock_stream()
 
         query = ReadRowsQuery(limit=start_limit)
         table = mock.Mock()
         table.table_name = "table_name"
         table.app_profile_id = "app_profile_id"
-        with mock.patch.object(
-            _ReadRowsOperationAsync, "_read_rows_attempt"
-        ) as mock_attempt:
-            mock_attempt.return_value = mock_stream()
-            instance = self._make_one(query, table, 10, 10)
-            assert instance._remaining_count == start_limit
-            with pytest.raises(RuntimeError) as e:
-                # read emit_num rows
-                async for val in instance.start_operation():
-                    pass
-            assert "emit count exceeds row limit" in str(e.value)
+        instance = self._make_one(query, table, 10, 10)
+        assert instance._remaining_count == start_limit
+        with pytest.raises(InvalidChunk) as e:
+            # read emit_num rows
+            async for val in instance.chunk_stream(awaitable_stream()):
+                pass
+        assert "emit count exceeds row limit" in str(e.value)
 
     @pytest.mark.asyncio
     async def test_aclose(self):
@@ -333,6 +354,7 @@ async def mock_stream():
         instance = mock.Mock()
         instance._last_yielded_row_key = None
+        instance._remaining_count = None
         stream = _ReadRowsOperationAsync.chunk_stream(instance, mock_awaitable_stream())
         await stream.__anext__()
         with pytest.raises(InvalidChunk) as exc:
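
For reference, judging from the new _build_exception signature above, the exception_factory hook follows a small contract: when the retry helper gives up, it calls the factory with the exceptions collected across attempts, a timeout flag, and the timeout value, and the factory returns the exception to raise plus an optional cause. The snippet below is a rough, self-contained sketch of that contract; it deliberately does not import google.api_core, and the names build_exception and DeadlineExceededStub are stand-ins, not the library's actual classes.

# Rough sketch of the exception_factory contract used by start_operation above.
# Stand-in types only: build_exception / DeadlineExceededStub are hypothetical.
from __future__ import annotations


class DeadlineExceededStub(Exception):
    """Placeholder for core_exceptions.DeadlineExceeded."""


def build_exception(
    exc_list: list[Exception], is_timeout: bool, timeout_val: float
) -> tuple[Exception, Exception | None]:
    # same shape as _ReadRowsOperationAsync._build_exception: pick a primary
    # exception, and chain the remaining transient errors as its cause
    if is_timeout:
        source: Exception = DeadlineExceededStub(
            f"operation_timeout of {timeout_val} exceeded"
        )
    elif exc_list:
        source = exc_list.pop()
    else:
        source = RuntimeError("failed with unspecified exception")
    cause = Exception(f"{len(exc_list)} transient error(s)") if exc_list else None
    source.__cause__ = cause
    return source, cause


if __name__ == "__main__":
    # simulate a retry loop that timed out after two transient failures
    transient: list[Exception] = [ValueError("attempt 1 failed"), ValueError("attempt 2 failed")]
    exc, cause = build_exception(transient, is_timeout=True, timeout_val=10.0)
    print(type(exc).__name__, "<-", cause)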