diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py
index baae205d9..0f78520f1 100644
--- a/google/cloud/bigtable/data/_async/_mutate_rows.py
+++ b/google/cloud/bigtable/data/_async/_mutate_rows.py
@@ -24,6 +24,7 @@
 from google.cloud.bigtable.data._helpers import _make_metadata
 from google.cloud.bigtable.data._helpers import _convert_retry_deadline
 from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
+from google.cloud.bigtable.data._helpers import _exponential_sleep_generator
 
 # mutate_rows requests are limited to this number of mutations
 from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
@@ -89,14 +90,13 @@ def __init__(
             bt_exceptions._MutateRowsIncomplete,
         )
         # build retryable operation
-        retry = retries.AsyncRetry(
+        retry_wrapped = functools.partial(
+            retries.retry_target,
+            target=self._run_attempt,
             predicate=self.is_retryable,
+            sleep_generator=_exponential_sleep_generator(),
             timeout=operation_timeout,
-            initial=0.01,
-            multiplier=2,
-            maximum=60,
         )
-        retry_wrapped = retry(self._run_attempt)
         self._operation = _convert_retry_deadline(
             retry_wrapped, operation_timeout, is_async=True
         )
@@ -104,6 +104,7 @@ def __init__(
         self.timeout_generator = _attempt_timeout_generator(
             attempt_timeout, operation_timeout
         )
+        self.operation_timeout = operation_timeout
         self.mutations = mutation_entries
         self.remaining_indices = list(range(len(self.mutations)))
         self.errors: dict[int, list[Exception]] = {}
diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py
index 20b5618ea..bd0a50e04 100644
--- a/google/cloud/bigtable/data/_async/_read_rows.py
+++ b/google/cloud/bigtable/data/_async/_read_rows.py
@@ -29,10 +29,10 @@
 from google.cloud.bigtable.data.exceptions import _RowSetComplete
 from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
 from google.cloud.bigtable.data._helpers import _make_metadata
+from google.cloud.bigtable.data._helpers import _exponential_sleep_generator
 
 from google.api_core import retry_async as retries
 from google.api_core.retry_streaming_async import retry_target_stream
-from google.api_core.retry import exponential_sleep_generator
 from google.api_core import exceptions as core_exceptions
 
 if TYPE_CHECKING:
@@ -107,7 +107,7 @@ def start_operation(self) -> AsyncGenerator[Row, None]:
         return retry_target_stream(
             self._read_rows_attempt,
             self._predicate,
-            exponential_sleep_generator(0.01, 60, multiplier=2),
+            _exponential_sleep_generator(),
             self.operation_timeout,
             exception_factory=self._build_exception,
         )
diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py
index 8524cd9aa..0f17c4929 100644
--- a/google/cloud/bigtable/data/_async/client.py
+++ b/google/cloud/bigtable/data/_async/client.py
@@ -32,6 +32,7 @@
 import random
 import os
 
+from functools import partial
 from collections import namedtuple
 
 from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta
@@ -62,9 +63,10 @@
 from google.cloud.bigtable.data._helpers import _make_metadata
 from google.cloud.bigtable.data._helpers import _convert_retry_deadline
 from google.cloud.bigtable.data._helpers import _validate_timeouts
+from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
+from google.cloud.bigtable.data._helpers import _exponential_sleep_generator
 from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync
 from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE
-from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
 from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule
 from google.cloud.bigtable.data.row_filters import RowFilter
@@ -835,16 +837,6 @@ def on_error_fn(exc):
             if predicate(exc):
                 transient_errors.append(exc)
 
-        retry = retries.AsyncRetry(
-            predicate=predicate,
-            timeout=operation_timeout,
-            initial=0.01,
-            multiplier=2,
-            maximum=60,
-            on_error=on_error_fn,
-            is_stream=False,
-        )
-
         # prepare request
         metadata = _make_metadata(self.table_name, self.app_profile_id)
@@ -857,8 +849,16 @@ async def execute_rpc():
             )
             return [(s.row_key, s.offset_bytes) async for s in results]
 
+        retry_wrapped = partial(
+            retries.retry_target,
+            target=execute_rpc,
+            predicate=predicate,
+            on_error=on_error_fn,
+            sleep_generator=_exponential_sleep_generator(),
+            timeout=operation_timeout,
+        )
         wrapped_fn = _convert_retry_deadline(
-            retry(execute_rpc), operation_timeout, transient_errors, is_async=True
+            retry_wrapped, operation_timeout, transient_errors, is_async=True
         )
         return await wrapped_fn()
@@ -973,25 +973,29 @@ def on_error_fn(exc):
             if predicate(exc):
                 transient_errors.append(exc)
 
-        retry = retries.AsyncRetry(
+        # create gapic request
+        gapic_fn = partial(
+            self.client._gapic_client.mutate_row,
+            request,
+            timeout=attempt_timeout,
+            metadata=_make_metadata(self.table_name, self.app_profile_id),
+            retry=None,
+        )
+        # wrap rpc in retry logic
+        retry_wrapped = partial(
+            retries.retry_target,
+            target=gapic_fn,
             predicate=predicate,
             on_error=on_error_fn,
+            sleep_generator=_exponential_sleep_generator(),
             timeout=operation_timeout,
-            initial=0.01,
-            multiplier=2,
-            maximum=60,
         )
-        # wrap rpc in retry logic
-        retry_wrapped = retry(self.client._gapic_client.mutate_row)
         # convert RetryErrors from retry wrapper into DeadlineExceeded errors
         deadline_wrapped = _convert_retry_deadline(
             retry_wrapped, operation_timeout, transient_errors, is_async=True
         )
-        metadata = _make_metadata(self.table_name, self.app_profile_id)
         # trigger rpc
-        await deadline_wrapped(
-            request, timeout=attempt_timeout, metadata=metadata, retry=None
-        )
+        await deadline_wrapped()
 
     async def bulk_mutate_rows(
         self,
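Across all three files the change is the same: the `retries.AsyncRetry` wrapper, whose backoff is only tunable through the `initial`/`multiplier`/`maximum` knobs, is replaced by a `functools.partial` around `retries.retry_target`, which accepts an arbitrary `sleep_generator`. A minimal standalone sketch of the new calling pattern follows; the `flaky_rpc` coroutine and its 0.7 failure rate are invented for illustration, and api_core's stock `exponential_sleep_generator` stands in for the `_exponential_sleep_generator` helper added in `_helpers.py` below:

```python
import asyncio
import functools
import random

from google.api_core import exceptions as core_exceptions
from google.api_core import retry_async as retries
from google.api_core.retry import exponential_sleep_generator


async def flaky_rpc() -> str:
    # Invented stand-in for a single RPC attempt.
    if random.random() < 0.7:
        raise core_exceptions.DeadlineExceeded("transient failure")
    return "ok"


async def main() -> None:
    # retry_target re-invokes `target` with no arguments on each attempt,
    # sleeping between attempts on values drawn from `sleep_generator`.
    retry_wrapped = functools.partial(
        retries.retry_target,
        target=flaky_rpc,
        predicate=retries.if_exception_type(core_exceptions.DeadlineExceeded),
        sleep_generator=exponential_sleep_generator(0.01, 60, multiplier=2),
        timeout=30,
    )
    print(await retry_wrapped())


asyncio.run(main())
```

Because `retry_target` calls its target with no arguments, per-request settings must be bound in ahead of time; that is why the `mutate_row` path above first builds `gapic_fn` with `partial` before wrapping it.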
diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py
index 1f8a63d21..af81846b9 100644
--- a/google/cloud/bigtable/data/_helpers.py
+++ b/google/cloud/bigtable/data/_helpers.py
@@ -15,6 +15,7 @@
 
 from typing import Callable, Any
 import time
+import random
 
 from google.api_core import exceptions as core_exceptions
 from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
@@ -62,6 +63,51 @@ def _attempt_timeout_generator(
         yield max(0, min(per_request_timeout, deadline - time.monotonic()))
 
 
+def _exponential_sleep_generator(
+    initial: float = 0.01,
+    maximum: float = 60,
+    multiplier: float = 2,
+    min_increase: float = 0.01,
+):
+    """
+    Generates sleep intervals for exponential backoff on failed rpc attempts.
+
+    Based on google.api_core.retry.exponential_sleep_generator,
+    but with the added constraint that each sleep interval must be strictly
+    greater than the previous one.
+
+    Args:
+        initial: The starting delay value, in seconds. Subsequent values will
+            always be greater than this value. Must be > 0.
+        maximum: The maximum amount of time to delay, in seconds. Must be
+            >= initial.
+        multiplier: The multiplier applied to the delay. Modifies the upper range
+            of sleep values that may be returned. Must be >= 1.
+        min_increase: The minimum amount of time to increase the delay,
+            in seconds. Modifies the lower range of sleep values that may be
+            returned. min_increase will not be applied if it would cause the
+            value to exceed maximum. Must be >= 0.
+    Yields:
+        float: successive sleep intervals for exponential backoff, in seconds.
+    """
+    if initial <= 0:
+        raise ValueError("initial must be > 0")
+    if multiplier < 1:
+        raise ValueError("multiplier must be >= 1")
+    if maximum < initial:
+        raise ValueError("maximum must be >= initial")
+    if min_increase < 0:
+        raise ValueError("min_increase must be >= 0")
+    lower_bound = initial
+    upper_bound = initial
+    next_sleep = initial
+    while True:
+        yield next_sleep
+        lower_bound = min(next_sleep + min_increase, maximum)
+        upper_bound = min(max(upper_bound * multiplier, lower_bound), maximum)
+        next_sleep = random.uniform(lower_bound, upper_bound)
+
+
 # TODO:replace this function with an exception_factory passed into the retry when
 # feature is merged:
 # https://github.com/googleapis/python-bigtable/blob/ea5b4f923e42516729c57113ddbe28096841b952/google/cloud/bigtable/data/_async/_read_rows.py#L130
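The behavioral difference from the stock api_core generator is the ratcheting lower bound: api_core's `random.uniform(0.0, max_delay)` can yield sleeps arbitrarily close to zero on every attempt, while here each draw starts at `min_increase` above the previous sleep. A quick standalone check of that property; the helper body is copied inline (argument validation omitted) so the snippet runs without the library, and the seed is arbitrary:

```python
import itertools
import random


def _exponential_sleep_generator(
    initial=0.01, maximum=60, multiplier=2, min_increase=0.01
):
    # Inline copy of the helper above, with argument validation omitted.
    lower_bound = upper_bound = next_sleep = initial
    while True:
        yield next_sleep
        lower_bound = min(next_sleep + min_increase, maximum)
        upper_bound = min(max(upper_bound * multiplier, lower_bound), maximum)
        next_sleep = random.uniform(lower_bound, upper_bound)


random.seed(0)  # arbitrary seed, for a reproducible run
sleeps = list(itertools.islice(_exponential_sleep_generator(), 8))
# Until the maximum is reached, each sleep is at least min_increase above the last.
assert all(b >= a + 0.01 for a, b in zip(sleeps, sleeps[1:]))
print(sleeps)
```

This is the property the new unit tests below pin down by mocking `random.uniform` to return each end of its range.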
diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py
index 08bc397c3..bafc8ff08 100644
--- a/tests/unit/data/test__helpers.py
+++ b/tests/unit/data/test__helpers.py
@@ -97,6 +97,112 @@ def test_attempt_timeout_w_sleeps(self):
         expected_value -= sleep_time
 
 
+class TestExponentialSleepGenerator:
+    @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: m)
+    @pytest.mark.parametrize(
+        "args,expected",
+        [
+            ((), [0.01, 0.02, 0.03, 0.04, 0.05]),  # test defaults
+            ((1, 3, 2, 1), [1, 2, 3, 3, 3]),  # test hitting limit
+            ((1, 3, 2, 0.5), [1, 1.5, 2, 2.5, 3, 3]),  # test with smaller min_increase
+            ((0.92, 3, 2, 0), [0.92, 0.92, 0.92]),  # test with min_increase of 0
+            ((1, 3, 10, 0.5), [1, 1.5, 2, 2.5, 3, 3]),  # test with larger multiplier
+            ((1, 25, 1.5, 5), [1, 6, 11, 16, 21, 25]),  # test with larger min increase
+            ((1, 5, 1, 0), [1, 1, 1, 1]),  # test with multiplier of 1
+            ((1, 5, 1, 1), [1, 2, 3, 4]),  # test with min_increase of 1 and multiplier of 1
+        ],
+    )
+    def test_exponential_sleep_generator_lower_bound(self, uniform, args, expected):
+        """
+        Test that _exponential_sleep_generator generates expected values when
+        random.uniform is mocked to return the lower bound of the range
+
+        Each yield should consistently be min_increase above the last
+        """
+        import itertools
+
+        gen = _helpers._exponential_sleep_generator(*args)
+        result = list(itertools.islice(gen, len(expected)))
+        assert result == expected
+
+    @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n)
+    @pytest.mark.parametrize(
+        "args,expected",
+        [
+            ((), [0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64, 1.28]),  # test defaults
+            ((1, 3, 2, 1), [1, 2, 3, 3, 3]),  # test hitting limit
+            ((1, 3, 2, 0.5), [1, 2, 3, 3]),  # test with smaller min_increase
+            ((0.92, 3, 2, 0), [0.92, 1.84, 3, 3]),  # test with min_increase of 0
+            ((1, 5000, 10, 0.5), [1, 10, 100, 1000]),  # test with larger multiplier
+            ((1, 20, 1.5, 5), [1, 6, 11, 16.5, 20]),  # test with larger min increase
+            ((1, 5, 1, 0), [1, 1, 1, 1]),  # test with multiplier of 1
+            ((1, 5, 1, 1), [1, 2, 3, 4]),  # test with min_increase of 1 and multiplier of 1
+        ],
+    )
+    def test_exponential_sleep_generator_upper_bound(self, uniform, args, expected):
+        """
+        Test that _exponential_sleep_generator generates expected values when
+        random.uniform is mocked to return the upper bound of the range
+
+        Each yield should be scaled by multiplier
+        """
+        import itertools
+
+        gen = _helpers._exponential_sleep_generator(*args)
+        result = list(itertools.islice(gen, len(expected)))
+        assert result == expected
+
+    @pytest.mark.parametrize(
+        "kwargs,exc_msg",
+        [
+            ({"initial": 0}, "initial must be > 0"),
+            ({"initial": -1}, "initial must be > 0"),
+            ({"multiplier": 0}, "multiplier must be >= 1"),
+            ({"multiplier": -1}, "multiplier must be >= 1"),
+            ({"multiplier": 0.9}, "multiplier must be >= 1"),
+            ({"min_increase": -1}, "min_increase must be >= 0"),
+            ({"min_increase": -0.1}, "min_increase must be >= 0"),
+            ({"initial": 1, "maximum": 0}, "maximum must be >= initial"),
+            ({"initial": 2, "maximum": 1}, "maximum must be >= initial"),
+            ({"initial": 2, "maximum": 1.99}, "maximum must be >= initial"),
+        ],
+    )
+    def test_exponential_sleep_generator_bad_arguments(self, kwargs, exc_msg):
+        """
+        Test that _exponential_sleep_generator raises ValueError when given
+        zero, negative, or otherwise out-of-range arguments
+        """
+        with pytest.raises(ValueError) as excinfo:
+            gen = _helpers._exponential_sleep_generator(**kwargs)
+            # start generator; validation runs on the first next()
+            next(gen)
+        assert exc_msg in str(excinfo.value)
+
+    @pytest.mark.parametrize(
+        "kwargs",
+        [
+            {},
+            {"multiplier": 1},
+            {"multiplier": 1.1},
+            {"multiplier": 2},
+            {"min_increase": 0},
+            {"min_increase": 0.1},
+            {"min_increase": 100},
+            {"multiplier": 1, "min_increase": 0},
+            {"multiplier": 1, "min_increase": 4},
+        ],
+    )
+    def test_exponential_sleep_generator_always_increases(self, kwargs):
+        """
+        Generate a long run of sleep values without mocking random, to ensure
+        they never decrease
+        """
+        gen = _helpers._exponential_sleep_generator(**kwargs, maximum=float("inf"))
+        last = next(gen)
+        for i in range(100):
+            current = next(gen)
+            assert current >= last
+            last = current
+
+
 class TestConvertRetryDeadline:
     """
     Test _convert_retry_deadline wrapper