
feat: optimize retries (#854)
daniel-sanche authored Aug 17, 2023
1 parent 0b3606f commit b6d232a
Showing 6 changed files with 90 additions and 64 deletions.
71 changes: 36 additions & 35 deletions google/cloud/bigtable/data/_async/_read_rows.py
@@ -31,7 +31,7 @@
 from google.cloud.bigtable.data._helpers import _make_metadata

 from google.api_core import retry_async as retries
-from google.api_core.retry_streaming_async import AsyncRetryableGenerator
+from google.api_core.retry_streaming_async import retry_target_stream
 from google.api_core.retry import exponential_sleep_generator
 from google.api_core import exceptions as core_exceptions

@@ -100,35 +100,17 @@ def __init__(
         self._last_yielded_row_key: bytes | None = None
         self._remaining_count: int | None = self.request.rows_limit or None

-    async def start_operation(self) -> AsyncGenerator[Row, None]:
+    def start_operation(self) -> AsyncGenerator[Row, None]:
         """
         Start the read_rows operation, retrying on retryable errors.
         """
-        transient_errors = []
-
-        def on_error_fn(exc):
-            if self._predicate(exc):
-                transient_errors.append(exc)
-
-        retry_gen = AsyncRetryableGenerator(
+        return retry_target_stream(
             self._read_rows_attempt,
             self._predicate,
             exponential_sleep_generator(0.01, 60, multiplier=2),
             self.operation_timeout,
-            on_error_fn,
+            exception_factory=self._build_exception,
         )
-        try:
-            async for row in retry_gen:
-                yield row
-                if self._remaining_count is not None:
-                    self._remaining_count -= 1
-                    if self._remaining_count < 0:
-                        raise RuntimeError("emit count exceeds row limit")
-        except core_exceptions.RetryError:
-            self._raise_retry_error(transient_errors)
-        except GeneratorExit:
-            # propagate close to wrapped generator
-            await retry_gen.aclose()

     def _read_rows_attempt(self) -> AsyncGenerator[Row, None]:
         """
@@ -202,6 +184,10 @@ async def chunk_stream(
                 elif c.commit_row:
                     # update row state after each commit
                     self._last_yielded_row_key = current_key
+                    if self._remaining_count is not None:
+                        self._remaining_count -= 1
+                        if self._remaining_count < 0:
+                            raise InvalidChunk("emit count exceeds row limit")
                     current_key = None

     @staticmethod
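
With the retry loop delegated to the library, the row-limit bookkeeping moves from start_operation into chunk_stream, firing once per committed row and raising InvalidChunk instead of RuntimeError. The idea in isolation, as a self-contained sketch with hypothetical names (enforce_limit is not a library function):

    import asyncio

    async def enforce_limit(rows, limit):
        # Decrement a remaining-row budget per emitted row; a negative budget
        # means the server sent more rows than the request's limit allowed.
        remaining = limit
        async for row in rows:
            yield row
            if remaining is not None:
                remaining -= 1
                if remaining < 0:
                    raise ValueError("emit count exceeds row limit")

    async def demo():
        async def rows():
            for i in range(3):
                yield i
        print([r async for r in enforce_limit(rows(), limit=3)])  # [0, 1, 2]

    asyncio.run(demo())
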
@@ -354,19 +340,34 @@ def _revise_request_rowset(
             raise _RowSetComplete()
         return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges)

-    def _raise_retry_error(self, transient_errors: list[Exception]) -> None:
+    @staticmethod
+    def _build_exception(
+        exc_list: list[Exception], is_timeout: bool, timeout_val: float
+    ) -> tuple[Exception, Exception | None]:
         """
-        If the retryable deadline is hit, wrap the raised exception
-        in a RetryExceptionGroup
+        Build retry error based on exceptions encountered during operation
+
+        Args:
+          - exc_list: list of exceptions encountered during operation
+          - is_timeout: whether the operation failed due to timeout
+          - timeout_val: the operation timeout value in seconds, for constructing
+            the error message
+        Returns:
+          - tuple of the exception to raise, and a cause exception if applicable
         """
-        timeout_value = self.operation_timeout
-        timeout_str = f" of {timeout_value:.1f}s" if timeout_value is not None else ""
-        error_str = f"operation_timeout{timeout_str} exceeded"
-        new_exc = core_exceptions.DeadlineExceeded(
-            error_str,
-        )
-        source_exc = None
-        if transient_errors:
-            source_exc = RetryExceptionGroup(transient_errors)
-        new_exc.__cause__ = source_exc
-        raise new_exc from source_exc
+        if is_timeout:
+            # if failed due to timeout, raise deadline exceeded as primary exception
+            source_exc: Exception = core_exceptions.DeadlineExceeded(
+                f"operation_timeout of {timeout_val} exceeded"
+            )
+        elif exc_list:
+            # otherwise, raise non-retryable error as primary exception
+            source_exc = exc_list.pop()
+        else:
+            source_exc = RuntimeError("failed with unspecified exception")
+        # use the retry exception group as the cause of the exception
+        cause_exc: Exception | None = (
+            RetryExceptionGroup(exc_list) if exc_list else None
+        )
+        source_exc.__cause__ = cause_exc
+        return source_exc, cause_exc
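
_raise_retry_error thus becomes a static _build_exception factory: instead of raising inside the operation, it returns the primary exception plus an optional RetryExceptionGroup cause, and the retry layer does the raising. Roughly how a caller consumes that contract (an illustrative sketch, not the library's internals):

    # Hypothetical consumer of the exception_factory contract shown above.
    def raise_terminal_error(factory, exc_list, is_timeout, timeout_val):
        source_exc, cause_exc = factory(exc_list, is_timeout, timeout_val)
        # "raise X from Y" sets __cause__, matching the chaining done above.
        raise source_exc from cause_exc
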
3 changes: 3 additions & 0 deletions google/cloud/bigtable/data/_helpers.py
@@ -62,6 +62,9 @@ def _attempt_timeout_generator(
         yield max(0, min(per_request_timeout, deadline - time.monotonic()))


+# TODO:replace this function with an exception_factory passed into the retry when
+# feature is merged:
+# https://github.com/googleapis/python-bigtable/blob/ea5b4f923e42516729c57113ddbe28096841b952/google/cloud/bigtable/data/_async/_read_rows.py#L130
 def _convert_retry_deadline(
     func: Callable[..., Any],
     timeout_value: float | None = None,
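
The visible line of _attempt_timeout_generator shows the attempt-budget pattern this module relies on: each attempt receives the per-request timeout, capped by whatever remains of the overall operation deadline. A self-contained sketch inferred from that single line (the real helper's signature and defaults may differ):

    import time

    def attempt_timeout_generator(per_request_timeout, operation_timeout):
        # Yields a fresh per-attempt timeout forever; the value shrinks as the
        # overall operation deadline approaches and bottoms out at 0.
        deadline = operation_timeout + time.monotonic()
        while True:
            yield max(0, min(per_request_timeout, deadline - time.monotonic()))

    timeouts = attempt_timeout_generator(5.0, 12.0)
    print(next(timeouts))  # 5.0 while plenty of the 12s budget remains
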
2 changes: 1 addition & 1 deletion setup.py
@@ -37,7 +37,7 @@
 # 'Development Status :: 5 - Production/Stable'
 release_status = "Development Status :: 5 - Production/Stable"
 dependencies = [
-    "google-api-core[grpc] == 2.12.0.dev0",  # TODO: change to >= after streaming retries is merged
+    "google-api-core[grpc] == 2.12.0.dev1",  # TODO: change to >= after streaming retries is merged
     "google-cloud-core >= 1.4.1, <3.0.0dev",
     "grpc-google-iam-v1 >= 0.12.4, <1.0.0dev",
     "proto-plus >= 1.22.0, <2.0.0dev",
2 changes: 1 addition & 1 deletion testing/constraints-3.7.txt
@@ -5,7 +5,7 @@
 #
 # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
 # Then this file should have foo==1.14.0
-google-api-core==2.12.0.dev0
+google-api-core==2.12.0.dev1
 google-cloud-core==2.3.2
 grpc-google-iam-v1==0.12.4
 proto-plus==1.22.0
74 changes: 48 additions & 26 deletions tests/unit/data/_async/test__read_rows.py
@@ -226,24 +226,34 @@ async def test_revise_limit(self, start_limit, emit_num, expected_limit):
         should be raised (tested in test_revise_limit_over_limit)
         """
         from google.cloud.bigtable.data import ReadRowsQuery
+        from google.cloud.bigtable_v2.types import ReadRowsResponse

-        async def mock_stream():
-            for i in range(emit_num):
-                yield i
+        async def awaitable_stream():
+            async def mock_stream():
+                for i in range(emit_num):
+                    yield ReadRowsResponse(
+                        chunks=[
+                            ReadRowsResponse.CellChunk(
+                                row_key=str(i).encode(),
+                                family_name="b",
+                                qualifier=b"c",
+                                value=b"d",
+                                commit_row=True,
+                            )
+                        ]
+                    )
+
+            return mock_stream()

         query = ReadRowsQuery(limit=start_limit)
         table = mock.Mock()
         table.table_name = "table_name"
         table.app_profile_id = "app_profile_id"
-        with mock.patch.object(
-            _ReadRowsOperationAsync, "_read_rows_attempt"
-        ) as mock_attempt:
-            mock_attempt.return_value = mock_stream()
-            instance = self._make_one(query, table, 10, 10)
-            assert instance._remaining_count == start_limit
-            # read emit_num rows
-            async for val in instance.start_operation():
-                pass
+        instance = self._make_one(query, table, 10, 10)
+        assert instance._remaining_count == start_limit
+        # read emit_num rows
+        async for val in instance.chunk_stream(awaitable_stream()):
+            pass
         assert instance._remaining_count == expected_limit

     @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)])
@@ -254,26 +264,37 @@ async def test_revise_limit_over_limit(self, start_limit, emit_num):
         (unless start_num == 0, which represents unlimited)
         """
         from google.cloud.bigtable.data import ReadRowsQuery
+        from google.cloud.bigtable_v2.types import ReadRowsResponse
+        from google.cloud.bigtable.data.exceptions import InvalidChunk

-        async def mock_stream():
-            for i in range(emit_num):
-                yield i
+        async def awaitable_stream():
+            async def mock_stream():
+                for i in range(emit_num):
+                    yield ReadRowsResponse(
+                        chunks=[
+                            ReadRowsResponse.CellChunk(
+                                row_key=str(i).encode(),
+                                family_name="b",
+                                qualifier=b"c",
+                                value=b"d",
+                                commit_row=True,
+                            )
+                        ]
+                    )
+
+            return mock_stream()

         query = ReadRowsQuery(limit=start_limit)
         table = mock.Mock()
         table.table_name = "table_name"
         table.app_profile_id = "app_profile_id"
-        with mock.patch.object(
-            _ReadRowsOperationAsync, "_read_rows_attempt"
-        ) as mock_attempt:
-            mock_attempt.return_value = mock_stream()
-            instance = self._make_one(query, table, 10, 10)
-            assert instance._remaining_count == start_limit
-            with pytest.raises(RuntimeError) as e:
-                # read emit_num rows
-                async for val in instance.start_operation():
-                    pass
-            assert "emit count exceeds row limit" in str(e.value)
+        instance = self._make_one(query, table, 10, 10)
+        assert instance._remaining_count == start_limit
+        with pytest.raises(InvalidChunk) as e:
+            # read emit_num rows
+            async for val in instance.chunk_stream(awaitable_stream()):
+                pass
+        assert "emit count exceeds row limit" in str(e.value)

     @pytest.mark.asyncio
     async def test_aclose(self):
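
Both updated tests feed chunk_stream through the same fixture shape: judging by the call sites, chunk_stream takes an awaitable that resolves to the response stream, so each mock wraps an async generator in a coroutine. The pattern in isolation:

    import asyncio

    async def awaitable_stream():
        # A coroutine that resolves to an async iterator, mirroring the mocks above.
        async def gen():
            for i in range(2):
                yield i
        return gen()

    async def demo():
        stream = await awaitable_stream()
        print([x async for x in stream])  # [0, 1]

    asyncio.run(demo())
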
@@ -333,6 +354,7 @@ async def mock_stream():

         instance = mock.Mock()
         instance._last_yielded_row_key = None
+        instance._remaining_count = None
         stream = _ReadRowsOperationAsync.chunk_stream(instance, mock_awaitable_stream())
         await stream.__anext__()
         with pytest.raises(InvalidChunk) as exc:
