Skip to content

Commit

Permalink
feat: expose retryable error codes to users (#879)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche authored Dec 1, 2023
1 parent b191451 commit 285cdd3
Show file tree
Hide file tree
Showing 9 changed files with 514 additions and 204 deletions.
6 changes: 3 additions & 3 deletions google/cloud/bigtable/data/_async/_mutate_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Sequence, TYPE_CHECKING
import asyncio
from dataclasses import dataclass
import functools
Expand Down Expand Up @@ -66,6 +66,7 @@ def __init__(
mutation_entries: list["RowMutationEntry"],
operation_timeout: float,
attempt_timeout: float | None,
retryable_exceptions: Sequence[type[Exception]] = (),
):
"""
Args:
Expand Down Expand Up @@ -96,8 +97,7 @@ def __init__(
# create predicate for determining which errors are retryable
self.is_retryable = retries.if_exception_type(
# RPC level errors
core_exceptions.DeadlineExceeded,
core_exceptions.ServiceUnavailable,
*retryable_exceptions,
# Entry level errors
bt_exceptions._MutateRowsIncomplete,
)
Expand Down
15 changes: 9 additions & 6 deletions google/cloud/bigtable/data/_async/_read_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@

from __future__ import annotations

from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable
from typing import (
TYPE_CHECKING,
AsyncGenerator,
AsyncIterable,
Awaitable,
Sequence,
)

from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB
from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB
Expand Down Expand Up @@ -74,6 +80,7 @@ def __init__(
table: "TableAsync",
operation_timeout: float,
attempt_timeout: float,
retryable_exceptions: Sequence[type[Exception]] = (),
):
self.attempt_timeout_gen = _attempt_timeout_generator(
attempt_timeout, operation_timeout
Expand All @@ -88,11 +95,7 @@ def __init__(
else:
self.request = query._to_pb(table)
self.table = table
self._predicate = retries.if_exception_type(
core_exceptions.DeadlineExceeded,
core_exceptions.ServiceUnavailable,
core_exceptions.Aborted,
)
self._predicate = retries.if_exception_type(*retryable_exceptions)
self._metadata = _make_metadata(
table.table_name,
table.app_profile_id,
Expand Down
140 changes: 108 additions & 32 deletions google/cloud/bigtable/data/_async/client.py

Large diffs are not rendered by default.

12 changes: 11 additions & 1 deletion google/cloud/bigtable/data/_async/mutations_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#
from __future__ import annotations

from typing import Any, TYPE_CHECKING
from typing import Any, Sequence, TYPE_CHECKING
import asyncio
import atexit
import warnings
Expand All @@ -23,6 +23,7 @@
from google.cloud.bigtable.data.mutations import RowMutationEntry
from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
from google.cloud.bigtable.data._helpers import _get_retryable_errors
from google.cloud.bigtable.data._helpers import _get_timeouts
from google.cloud.bigtable.data._helpers import TABLE_DEFAULT

Expand Down Expand Up @@ -192,6 +193,8 @@ def __init__(
flow_control_max_bytes: int = 100 * _MB_SIZE,
batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
batch_retryable_errors: Sequence[type[Exception]]
| TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
):
"""
Args:
Expand All @@ -208,10 +211,16 @@ def __init__(
- batch_attempt_timeout: timeout for each individual request, in seconds.
If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout.
If None, defaults to batch_operation_timeout.
- batch_retryable_errors: a list of errors that will be retried if encountered.
Defaults to the Table's default_mutate_rows_retryable_errors.
"""
self._operation_timeout, self._attempt_timeout = _get_timeouts(
batch_operation_timeout, batch_attempt_timeout, table
)
self._retryable_errors: list[type[Exception]] = _get_retryable_errors(
batch_retryable_errors, table
)

self.closed: bool = False
self._table = table
self._staged_entries: list[RowMutationEntry] = []
Expand Down Expand Up @@ -349,6 +358,7 @@ async def _execute_mutate_rows(
batch,
operation_timeout=self._operation_timeout,
attempt_timeout=self._attempt_timeout,
retryable_exceptions=self._retryable_errors,
)
await operation.start()
except MutationsExceptionGroup as e:
Expand Down
31 changes: 29 additions & 2 deletions google/cloud/bigtable/data/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Helper functions used in various places in the library.
"""
from __future__ import annotations

from typing import Callable, List, Tuple, Any
from typing import Callable, Sequence, List, Tuple, Any, TYPE_CHECKING
import time
import enum
from collections import namedtuple
Expand All @@ -22,6 +25,10 @@
from google.api_core import exceptions as core_exceptions
from google.cloud.bigtable.data.exceptions import RetryExceptionGroup

if TYPE_CHECKING:
import grpc
from google.cloud.bigtable.data import TableAsync

"""
Helper functions used in various places in the library.
"""
Expand Down Expand Up @@ -142,7 +149,9 @@ def wrapper(*args, **kwargs):


def _get_timeouts(
operation: float | TABLE_DEFAULT, attempt: float | None | TABLE_DEFAULT, table
operation: float | TABLE_DEFAULT,
attempt: float | None | TABLE_DEFAULT,
table: "TableAsync",
) -> tuple[float, float]:
"""
Convert passed in timeout values to floats, using table defaults if necessary.
Expand Down Expand Up @@ -209,3 +218,21 @@ def _validate_timeouts(
elif attempt_timeout is not None:
if attempt_timeout <= 0:
raise ValueError("attempt_timeout must be greater than 0")


def _get_retryable_errors(
    call_codes: Sequence["grpc.StatusCode" | int | type[Exception]] | TABLE_DEFAULT,
    table: "TableAsync",
) -> list[type[Exception]]:
    """
    Convert a retryable-error specification into a list of exception classes.

    Args:
      - call_codes: either a sequence whose items are exception classes or
          gRPC status codes, or a TABLE_DEFAULT member requesting that the
          matching default be read from the table.
      - table: the table whose default_retryable_errors,
          default_read_rows_retryable_errors, or
          default_mutate_rows_retryable_errors attribute is used when
          call_codes is a TABLE_DEFAULT member.
    Returns:
      - a list of exception classes. Items that were already classes are
          passed through unchanged; any other item is mapped to the class of
          the exception that core_exceptions.from_grpc_status builds for it.
    """
    # swap a TABLE_DEFAULT sentinel for the table's configured value
    if call_codes == TABLE_DEFAULT.DEFAULT:
        call_codes = table.default_retryable_errors
    elif call_codes == TABLE_DEFAULT.READ_ROWS:
        call_codes = table.default_read_rows_retryable_errors
    elif call_codes == TABLE_DEFAULT.MUTATE_ROWS:
        call_codes = table.default_mutate_rows_retryable_errors

    resolved: list[type[Exception]] = []
    for entry in call_codes:
        if isinstance(entry, type):
            # already a class: keep as-is
            resolved.append(entry)
        else:
            # build a throwaway exception (empty message) from the status
            # code, and keep only its class
            placeholder = core_exceptions.from_grpc_status(entry, "")
            resolved.append(type(placeholder))
    return resolved
26 changes: 19 additions & 7 deletions tests/unit/data/_async/test__mutate_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ def _make_one(self, *args, **kwargs):
if not args:
kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock())
kwargs["table"] = kwargs.pop("table", AsyncMock())
kwargs["mutation_entries"] = kwargs.pop("mutation_entries", [])
kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5)
kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1)
kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ())
kwargs["mutation_entries"] = kwargs.pop("mutation_entries", [])
return self._target_class()(*args, **kwargs)

async def _mock_stream(self, mutation_list, error_dict):
Expand Down Expand Up @@ -78,15 +79,21 @@ def test_ctor(self):
from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto
from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete
from google.api_core.exceptions import DeadlineExceeded
from google.api_core.exceptions import ServiceUnavailable
from google.api_core.exceptions import Aborted

client = mock.Mock()
table = mock.Mock()
entries = [_make_mutation(), _make_mutation()]
operation_timeout = 0.05
attempt_timeout = 0.01
retryable_exceptions = ()
instance = self._make_one(
client, table, entries, operation_timeout, attempt_timeout
client,
table,
entries,
operation_timeout,
attempt_timeout,
retryable_exceptions,
)
# running gapic_fn should trigger a client call
assert client.mutate_rows.call_count == 0
Expand All @@ -110,8 +117,8 @@ def test_ctor(self):
assert next(instance.timeout_generator) == attempt_timeout
# ensure predicate is set
assert instance.is_retryable is not None
assert instance.is_retryable(DeadlineExceeded("")) is True
assert instance.is_retryable(ServiceUnavailable("")) is True
assert instance.is_retryable(DeadlineExceeded("")) is False
assert instance.is_retryable(Aborted("")) is False
assert instance.is_retryable(_MutateRowsIncomplete("")) is True
assert instance.is_retryable(RuntimeError("")) is False
assert instance.remaining_indices == list(range(len(entries)))
Expand Down Expand Up @@ -232,7 +239,7 @@ async def test_mutate_rows_exception(self, exc_type):

@pytest.mark.parametrize(
"exc_type",
[core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable],
[core_exceptions.DeadlineExceeded, RuntimeError],
)
@pytest.mark.asyncio
async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type):
Expand All @@ -256,7 +263,12 @@ async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type):
) as attempt_mock:
attempt_mock.side_effect = [expected_cause] * num_retries + [None]
instance = self._make_one(
client, table, entries, operation_timeout, operation_timeout
client,
table,
entries,
operation_timeout,
operation_timeout,
retryable_exceptions=(exc_type,),
)
await instance.start()
assert attempt_mock.call_count == num_retries + 1
Expand Down
Loading

0 comments on commit 285cdd3

Please sign in to comment.