diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py
index 4b01d0e6b..a68be5417 100644
--- a/google/cloud/bigtable/data/__init__.py
+++ b/google/cloud/bigtable/data/__init__.py
@@ -13,9 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-from typing import List, Tuple
-
 from google.cloud.bigtable import gapic_version as package_version
 
 from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
@@ -44,10 +41,10 @@
 from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
 from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
 
-# Type alias for the output of sample_keys
-RowKeySamples = List[Tuple[bytes, int]]
-# type alias for the output of query.shard()
-ShardedQuery = List[ReadRowsQuery]
+from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
+from google.cloud.bigtable.data._helpers import RowKeySamples
+from google.cloud.bigtable.data._helpers import ShardedQuery
+
 
 __version__: str = package_version.__version__
 
@@ -74,4 +71,5 @@
     "MutationsExceptionGroup",
     "ShardedReadRowsExceptionGroup",
     "ShardedQuery",
+    "TABLE_DEFAULT",
 )
diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py
index e5be1b2d3..c6637581c 100644
--- a/google/cloud/bigtable/data/_async/client.py
+++ b/google/cloud/bigtable/data/_async/client.py
@@ -32,7 +32,6 @@
 import random
 import os
 
-from collections import namedtuple
 from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta
 from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient
@@ -59,30 +58,26 @@
 from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry
 from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync
+from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
+from google.cloud.bigtable.data._helpers import _WarmedInstanceKey
+from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT
 from google.cloud.bigtable.data._helpers import _make_metadata
 from google.cloud.bigtable.data._helpers import _convert_retry_deadline
 from google.cloud.bigtable.data._helpers import _validate_timeouts
+from google.cloud.bigtable.data._helpers import _get_timeouts
+from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
 from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync
 from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE
-from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
-
 from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule
 from google.cloud.bigtable.data.row_filters import RowFilter
 from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter
 from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter
 from google.cloud.bigtable.data.row_filters import RowFilterChain
 
-if TYPE_CHECKING:
-    from google.cloud.bigtable.data import RowKeySamples
-    from google.cloud.bigtable.data import ShardedQuery
-
-# used by read_rows_sharded to limit how many requests are attempted in parallel
-_CONCURRENCY_LIMIT = 10
-# used to register instance data with the client for channel warming
-_WarmedInstanceKey = namedtuple(
-    "_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"]
-)
+if TYPE_CHECKING:
+    from google.cloud.bigtable.data._helpers import RowKeySamples
+    from google.cloud.bigtable.data._helpers import ShardedQuery
 
 
 class BigtableDataClientAsync(ClientWithProject):
@@ -525,8 +520,8 @@ async def read_rows_stream(
         self,
         query: ReadRowsQuery,
         *,
-        operation_timeout: float | None = None,
-        attempt_timeout: float | None = None,
+        operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+        attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
     ) -> AsyncIterable[Row]:
         """
         Read a set of rows from the table, based on the specified query.
@@ -538,12 +533,12 @@
             - query: contains details about which rows to return
             - operation_timeout: the time budget for the entire operation, in seconds.
                 Failed requests will be retried within the budget.
-                If None, defaults to the Table's default_read_rows_operation_timeout
+                Defaults to the Table's default_read_rows_operation_timeout
             - attempt_timeout: the time budget for an individual network request, in seconds.
                 If it takes longer than this time to complete, the request will be cancelled with
                 a DeadlineExceeded exception, and a retry will be attempted.
-                If None, defaults to the Table's default_read_rows_attempt_timeout,
-                or the operation_timeout if that is also None.
+                Defaults to the Table's default_read_rows_attempt_timeout.
+                If None, defaults to operation_timeout.
         Returns:
             - an asynchronous iterator that yields rows returned by the query
         Raises:
             - DeadlineExceeded: raised after operation timeout
                 will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
                 from any retries that failed
             - GoogleAPIError: raised if the request encounters an unrecoverable error
             - IdleTimeout: if iterator was abandoned
         """
-        operation_timeout = (
-            operation_timeout or self.default_read_rows_operation_timeout
-        )
-        attempt_timeout = (
-            attempt_timeout
-            or self.default_read_rows_attempt_timeout
-            or operation_timeout
+        operation_timeout, attempt_timeout = _get_timeouts(
+            operation_timeout, attempt_timeout, self
         )
-        _validate_timeouts(operation_timeout, attempt_timeout)
 
         row_merger = _ReadRowsOperationAsync(
             query,
@@ -575,8 +564,8 @@ async def read_rows(
         self,
         query: ReadRowsQuery,
         *,
-        operation_timeout: float | None = None,
-        attempt_timeout: float | None = None,
+        operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+        attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
     ) -> list[Row]:
         """
         Read a set of rows from the table, based on the specified query.
@@ -589,12 +578,12 @@
             - query: contains details about which rows to return
             - operation_timeout: the time budget for the entire operation, in seconds.
                 Failed requests will be retried within the budget.
-                If None, defaults to the Table's default_read_rows_operation_timeout
+                Defaults to the Table's default_read_rows_operation_timeout
            - attempt_timeout: the time budget for an individual network request, in seconds.
                 If it takes longer than this time to complete, the request will be cancelled with
                 a DeadlineExceeded exception, and a retry will be attempted.
-                If None, defaults to the Table's default_read_rows_attempt_timeout,
-                or the operation_timeout if that is also None.
+                Defaults to the Table's default_read_rows_attempt_timeout.
+                If None, defaults to operation_timeout.
Returns: - a list of Rows returned by the query Raises: @@ -615,8 +604,8 @@ async def read_row( row_key: str | bytes, *, row_filter: RowFilter | None = None, - operation_timeout: int | float | None = None, - attempt_timeout: int | float | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> Row | None: """ Read a single row from the table, based on the specified key. @@ -627,12 +616,12 @@ async def read_row( - query: contains details about which rows to return - operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget. - If None, defaults to the Table's default_read_rows_operation_timeout + Defaults to the Table's default_read_rows_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. - If None, defaults to the Table's default_read_rows_attempt_timeout, or the operation_timeout - if that is also None. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. Returns: - a Row object if the row exists, otherwise None Raises: @@ -657,8 +646,8 @@ async def read_rows_sharded( self, sharded_query: ShardedQuery, *, - operation_timeout: int | float | None = None, - attempt_timeout: int | float | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: """ Runs a sharded query in parallel, then return the results in a single list. @@ -677,12 +666,12 @@ async def read_rows_sharded( - sharded_query: a sharded query to execute - operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget. - If None, defaults to the Table's default_read_rows_operation_timeout + Defaults to the Table's default_read_rows_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. - If None, defaults to the Table's default_read_rows_attempt_timeout, or the operation_timeout - if that is also None. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. 
Raises: - ShardedReadRowsExceptionGroup: if any of the queries failed - ValueError: if the query_list is empty @@ -690,15 +679,9 @@ async def read_rows_sharded( if not sharded_query: raise ValueError("empty sharded_query") # reduce operation_timeout between batches - operation_timeout = ( - operation_timeout or self.default_read_rows_operation_timeout + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self ) - attempt_timeout = ( - attempt_timeout - or self.default_read_rows_attempt_timeout - or operation_timeout - ) - _validate_timeouts(operation_timeout, attempt_timeout) timeout_generator = _attempt_timeout_generator( operation_timeout, operation_timeout ) @@ -744,8 +727,8 @@ async def row_exists( self, row_key: str | bytes, *, - operation_timeout: int | float | None = None, - attempt_timeout: int | float | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> bool: """ Return a boolean indicating whether the specified row exists in the table. @@ -754,12 +737,12 @@ async def row_exists( - row_key: the key of the row to check - operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget. - If None, defaults to the Table's default_read_rows_operation_timeout + Defaults to the Table's default_read_rows_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. - If None, defaults to the Table's default_read_rows_attempt_timeout, or the operation_timeout - if that is also None. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. Returns: - a bool indicating whether the row exists Raises: @@ -785,8 +768,8 @@ async def row_exists( async def sample_row_keys( self, *, - operation_timeout: float | None = None, - attempt_timeout: float | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> RowKeySamples: """ Return a set of RowKeySamples that delimit contiguous sections of the table of @@ -801,13 +784,13 @@ async def sample_row_keys( Args: - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - If None, defaults to the Table's default_operation_timeout + Failed requests will be retried within the budget.i + Defaults to the Table's default_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. - If None, defaults to the Table's default_attempt_timeout, or the operation_timeout - if that is also None. + Defaults to the Table's default_attempt_timeout. + If None, defaults to operation_timeout. 
Returns: - a set of RowKeySamples the delimit contiguous sections of the table Raises: @@ -817,12 +800,9 @@ async def sample_row_keys( - GoogleAPIError: raised if the request encounters an unrecoverable error """ # prepare timeouts - operation_timeout = operation_timeout or self.default_operation_timeout - attempt_timeout = ( - attempt_timeout or self.default_attempt_timeout or operation_timeout + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self ) - _validate_timeouts(operation_timeout, attempt_timeout) - attempt_timeout_gen = _attempt_timeout_generator( attempt_timeout, operation_timeout ) @@ -873,8 +853,8 @@ def mutations_batcher( flush_limit_bytes: int = 20 * _MB_SIZE, flow_control_max_mutation_count: int = 100_000, flow_control_max_bytes: int = 100 * _MB_SIZE, - batch_operation_timeout: float | None = None, - batch_attempt_timeout: float | None = None, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ) -> MutationsBatcherAsync: """ Returns a new mutations batcher instance. @@ -890,11 +870,11 @@ def mutations_batcher( - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. - flow_control_max_mutation_count: Maximum number of inflight mutations. - flow_control_max_bytes: Maximum number of inflight bytes. - - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. If None, - table default_mutate_rows_operation_timeout will be used - - batch_attempt_timeout: timeout for each individual request, in seconds. If None, - table default_mutate_rows_attempt_timeout will be used, or batch_operation_timeout - if that is also None. + - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + Defaults to the Table's default_mutate_rows_operation_timeout + - batch_attempt_timeout: timeout for each individual request, in seconds. + Defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to batch_operation_timeout. Returns: - a MutationsBatcherAsync context manager that can batch requests """ @@ -914,8 +894,8 @@ async def mutate_row( row_key: str | bytes, mutations: list[Mutation] | Mutation, *, - operation_timeout: float | None = None, - attempt_timeout: float | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ): """ Mutates a row atomically. @@ -931,12 +911,12 @@ async def mutate_row( - mutations: the set of mutations to apply to the row - operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget. - If None, defaults to the Table's default_operation_timeout + Defaults to the Table's default_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. - If None, defaults to the Table's default_attempt_timeout, or the operation_timeout - if that is also None. + Defaults to the Table's default_attempt_timeout. + If None, defaults to operation_timeout. Raises: - DeadlineExceeded: raised after operation timeout will be chained with a RetryExceptionGroup containing all @@ -944,11 +924,9 @@ async def mutate_row( - GoogleAPIError: raised on non-idempotent operations that cannot be safely retried. 
""" - operation_timeout = operation_timeout or self.default_operation_timeout - attempt_timeout = ( - attempt_timeout or self.default_attempt_timeout or operation_timeout + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self ) - _validate_timeouts(operation_timeout, attempt_timeout) if isinstance(row_key, str): row_key = row_key.encode("utf-8") @@ -1000,8 +978,8 @@ async def bulk_mutate_rows( self, mutation_entries: list[RowMutationEntry], *, - operation_timeout: float | None = None, - attempt_timeout: float | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ): """ Applies mutations for multiple rows in a single batched request. @@ -1021,25 +999,19 @@ async def bulk_mutate_rows( in arbitrary order - operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget. - If None, defaults to the Table's default_mutate_rows_operation_timeout + Defaults to the Table's default_mutate_rows_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. - If None, defaults to the Table's default_mutate_rows_attempt_timeout, - or the operation_timeout if that is also None. + Defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to operation_timeout. Raises: - MutationsExceptionGroup if one or more mutations fails Contains details about any failed entries in .exceptions """ - operation_timeout = ( - operation_timeout or self.default_mutate_rows_operation_timeout - ) - attempt_timeout = ( - attempt_timeout - or self.default_mutate_rows_attempt_timeout - or operation_timeout + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self ) - _validate_timeouts(operation_timeout, attempt_timeout) operation = _MutateRowsOperationAsync( self.client._gapic_client, @@ -1057,7 +1029,7 @@ async def check_and_mutate_row( *, true_case_mutations: Mutation | list[Mutation] | None = None, false_case_mutations: Mutation | list[Mutation] | None = None, - operation_timeout: int | float | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> bool: """ Mutates a row atomically based on the output of a predicate filter @@ -1086,15 +1058,12 @@ async def check_and_mutate_row( `true_case_mutations is empty, and at most 100000. - operation_timeout: the time budget for the entire operation, in seconds. Failed requests will not be retried. Defaults to the Table's default_operation_timeout - if None. 
Returns: - bool indicating whether the predicate was true or false Raises: - GoogleAPIError exceptions from grpc call """ - operation_timeout = operation_timeout or self.default_operation_timeout - if operation_timeout <= 0: - raise ValueError("operation_timeout must be greater than 0") + operation_timeout, _ = _get_timeouts(operation_timeout, None, self) row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key if true_case_mutations is not None and not isinstance( true_case_mutations, list @@ -1128,7 +1097,7 @@ async def read_modify_write_row( row_key: str | bytes, rules: ReadModifyWriteRule | list[ReadModifyWriteRule], *, - operation_timeout: int | float | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> Row: """ Reads and modifies a row atomically according to input ReadModifyWriteRules, @@ -1145,15 +1114,15 @@ async def read_modify_write_row( Rules are applied in order, meaning that earlier rules will affect the results of later ones. - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will not be retried. Defaults to the Table's default_operation_timeout - if None. + Failed requests will not be retried. + Defaults to the Table's default_operation_timeout. Returns: - Row: containing cell data that was modified as part of the operation Raises: - GoogleAPIError exceptions from grpc call """ - operation_timeout = operation_timeout or self.default_operation_timeout + operation_timeout, _ = _get_timeouts(operation_timeout, None, self) row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key if operation_timeout <= 0: raise ValueError("operation_timeout must be greater than 0") diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 34e1bfb5d..7ff5f9a0b 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -23,7 +23,8 @@ from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.exceptions import FailedMutationEntryError -from google.cloud.bigtable.data._helpers import _validate_timeouts +from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync from google.cloud.bigtable.data._async._mutate_rows import ( @@ -189,8 +190,8 @@ def __init__( flush_limit_bytes: int = 20 * _MB_SIZE, flow_control_max_mutation_count: int = 100_000, flow_control_max_bytes: int = 100 * _MB_SIZE, - batch_operation_timeout: float | None = None, - batch_attempt_timeout: float | None = None, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ): """ Args: @@ -202,21 +203,15 @@ def __init__( - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. - flow_control_max_mutation_count: Maximum number of inflight mutations. - flow_control_max_bytes: Maximum number of inflight bytes. - - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. If None, - table default_mutate_rows_operation_timeout will be used - - batch_attempt_timeout: timeout for each individual request, in seconds. 
If None,
-                table default_mutate_rows_attempt_timeout will be used, or batch_operation_timeout
-                if that is also None.
+            - batch_operation_timeout: timeout for each mutate_rows operation, in seconds.
+                If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_operation_timeout.
+            - batch_attempt_timeout: timeout for each individual request, in seconds.
+                If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout.
+                If None, defaults to batch_operation_timeout.
         """
-        self._operation_timeout: float = (
-            batch_operation_timeout or table.default_mutate_rows_operation_timeout
+        self._operation_timeout, self._attempt_timeout = _get_timeouts(
+            batch_operation_timeout, batch_attempt_timeout, table
         )
-        self._attempt_timeout: float = (
-            batch_attempt_timeout
-            or table.default_mutate_rows_attempt_timeout
-            or self._operation_timeout
-        )
-        _validate_timeouts(self._operation_timeout, self._attempt_timeout)
         self.closed: bool = False
         self._table = table
         self._staged_entries: list[RowMutationEntry] = []
diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py
index 1f8a63d21..1d56926ff 100644
--- a/google/cloud/bigtable/data/_helpers.py
+++ b/google/cloud/bigtable/data/_helpers.py
@@ -13,8 +13,11 @@
 #
 from __future__ import annotations
 
-from typing import Callable, Any
+from typing import Callable, List, Tuple, Any
 import time
+import enum
+from collections import namedtuple
+from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
 
 from google.api_core import exceptions as core_exceptions
 from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
@@ -23,6 +26,30 @@
 Helper functions used in various places in the library.
 """
 
+# Type alias for the output of sample_keys
+RowKeySamples = List[Tuple[bytes, int]]
+
+# type alias for the output of query.shard()
+ShardedQuery = List[ReadRowsQuery]
+
+# used by read_rows_sharded to limit how many requests are attempted in parallel
+_CONCURRENCY_LIMIT = 10
+
+# used to register instance data with the client for channel warming
+_WarmedInstanceKey = namedtuple(
+    "_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"]
+)
+
+
+# enum used on method calls when table defaults should be used
+class TABLE_DEFAULT(enum.Enum):
+    # default for mutate_row, sample_row_keys, check_and_mutate_row, and read_modify_write_row
+    DEFAULT = "DEFAULT"
+    # default for read_rows, read_rows_stream, read_rows_sharded, row_exists, and read_row
+    READ_ROWS = "READ_ROWS_DEFAULT"
+    # default for bulk_mutate_rows and mutations_batcher
+    MUTATE_ROWS = "MUTATE_ROWS_DEFAULT"
+
 
 def _make_metadata(
     table_name: str, app_profile_id: str | None
@@ -114,6 +141,51 @@ def wrapper(*args, **kwargs):
         return wrapper_async if is_async else wrapper
 
 
+def _get_timeouts(
+    operation: float | TABLE_DEFAULT, attempt: float | None | TABLE_DEFAULT, table
+) -> tuple[float, float]:
+    """
+    Convert passed in timeout values to floats, using table defaults if necessary.
+
+    attempt will use operation value if None, or if larger than operation.
+
+    Will call _validate_timeouts on the outputs, and raise ValueError if the
+    resulting timeouts are invalid.
+
+    Args:
+        - operation: The timeout value to use for the entire operation, in seconds.
+        - attempt: The timeout value to use for each attempt, in seconds.
+        - table: The table to use for default values.
+    Returns:
+        - A tuple of (operation_timeout, attempt_timeout)
+    """
+    # load table defaults if necessary
+    if operation == TABLE_DEFAULT.DEFAULT:
+        final_operation = table.default_operation_timeout
+    elif operation == TABLE_DEFAULT.READ_ROWS:
+        final_operation = table.default_read_rows_operation_timeout
+    elif operation == TABLE_DEFAULT.MUTATE_ROWS:
+        final_operation = table.default_mutate_rows_operation_timeout
+    else:
+        final_operation = operation
+    if attempt == TABLE_DEFAULT.DEFAULT:
+        attempt = table.default_attempt_timeout
+    elif attempt == TABLE_DEFAULT.READ_ROWS:
+        attempt = table.default_read_rows_attempt_timeout
+    elif attempt == TABLE_DEFAULT.MUTATE_ROWS:
+        attempt = table.default_mutate_rows_attempt_timeout
+
+    if attempt is None:
+        # no timeout specified, use operation timeout for both
+        final_attempt = final_operation
+    else:
+        # cap attempt timeout at operation timeout
+        final_attempt = min(attempt, final_operation) if final_operation else attempt
+
+    _validate_timeouts(final_operation, final_attempt, allow_none=False)
+    return final_operation, final_attempt
+
+
 def _validate_timeouts(
     operation_timeout: float, attempt_timeout: float | None, allow_none: bool = False
 ):
@@ -128,6 +200,8 @@ def _validate_timeouts(
     Raises:
       - ValueError if operation_timeout or attempt_timeout are invalid.
     """
+    if operation_timeout is None:
+        raise ValueError("operation_timeout cannot be None")
     if operation_timeout <= 0:
         raise ValueError("operation_timeout must be greater than 0")
     if not allow_none and attempt_timeout is None:
diff --git a/noxfile.py b/noxfile.py
index 4b57e617f..e1d2f4acc 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -460,7 +460,7 @@ def prerelease_deps(session):
         # Exclude version 1.52.0rc1 which has a known issue. See https://github.com/grpc/grpc/issues/32163
         "grpcio!=1.52.0rc1",
         "grpcio-status",
-        "google-api-core",
+        "google-api-core==2.12.0.dev1",  # TODO: remove this once streaming retries is merged
         "proto-plus",
         "google-cloud-testutils",
         # dependencies of google-cloud-testutils"
diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py
index 08bc397c3..6c11fa86a 100644
--- a/tests/unit/data/test__helpers.py
+++ b/tests/unit/data/test__helpers.py
@@ -14,6 +14,7 @@
 import pytest
 
 import google.cloud.bigtable.data._helpers as _helpers
+from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
 import google.cloud.bigtable.data.exceptions as bigtable_exceptions
 
 import mock
@@ -199,3 +200,67 @@ def test_validate_with_inputs(self, args, expected):
        except ValueError:
            pass
        assert success == expected
+
+
+class TestGetTimeouts:
+    @pytest.mark.parametrize(
+        "input_times,input_table,expected",
+        [
+            ((2, 1), {}, (2, 1)),
+            ((2, 4), {}, (2, 2)),
+            ((2, None), {}, (2, 2)),
+            (
+                (TABLE_DEFAULT.DEFAULT, TABLE_DEFAULT.DEFAULT),
+                {"operation": 3, "attempt": 2},
+                (3, 2),
+            ),
+            (
+                (TABLE_DEFAULT.READ_ROWS, TABLE_DEFAULT.READ_ROWS),
+                {"read_rows_operation": 3, "read_rows_attempt": 2},
+                (3, 2),
+            ),
+            (
+                (TABLE_DEFAULT.MUTATE_ROWS, TABLE_DEFAULT.MUTATE_ROWS),
+                {"mutate_rows_operation": 3, "mutate_rows_attempt": 2},
+                (3, 2),
+            ),
+            ((10, TABLE_DEFAULT.DEFAULT), {"attempt": None}, (10, 10)),
+            ((10, TABLE_DEFAULT.DEFAULT), {"attempt": 5}, (10, 5)),
+            ((10, TABLE_DEFAULT.DEFAULT), {"attempt": 100}, (10, 10)),
+            ((TABLE_DEFAULT.DEFAULT, 10), {"operation": 12}, (12, 10)),
+            ((TABLE_DEFAULT.DEFAULT, 10), {"operation": 3}, (3, 3)),
+        ],
+    )
+    def test_get_timeouts(self, input_times, input_table, expected):
+        """
+        test input/output mappings for a variety of valid inputs
+        """
+        fake_table = mock.Mock()
+        for key in input_table.keys():
+            # set the default fields in our fake table mock
+            setattr(fake_table, f"default_{key}_timeout", input_table[key])
+        t1, t2 = _helpers._get_timeouts(input_times[0], input_times[1], fake_table)
+        assert t1 == expected[0]
+        assert t2 == expected[1]
+
+    @pytest.mark.parametrize(
+        "input_times,input_table",
+        [
+            ([0, 1], {}),
+            ([1, 0], {}),
+            ([None, 1], {}),
+            ([TABLE_DEFAULT.DEFAULT, 1], {"operation": None}),
+            ([TABLE_DEFAULT.DEFAULT, 1], {"operation": 0}),
+            ([1, TABLE_DEFAULT.DEFAULT], {"attempt": 0}),
+        ],
+    )
+    def test_get_timeouts_invalid(self, input_times, input_table):
+        """
+        test with inputs that should raise error during validation step
+        """
+        fake_table = mock.Mock()
+        for key in input_table.keys():
+            # set the default fields in our fake table mock
+            setattr(fake_table, f"default_{key}_timeout", input_table[key])
+        with pytest.raises(ValueError):
+            _helpers._get_timeouts(input_times[0], input_times[1], fake_table)
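
Reviewer note: the snippet below is not part of the patch; it is a minimal sketch of how the new TABLE_DEFAULT sentinel and the _get_timeouts helper introduced above resolve timeout values, mirroring the mock-table pattern used in the unit tests. The default timeout values assigned to the fake table here are made up for illustration and are not the library's real defaults.

from unittest import mock

from google.cloud.bigtable.data import TABLE_DEFAULT
from google.cloud.bigtable.data._helpers import _get_timeouts

# Stand-in for a table instance; _get_timeouts only reads the default_* attributes.
fake_table = mock.Mock()
fake_table.default_read_rows_operation_timeout = 600.0  # illustrative value
fake_table.default_read_rows_attempt_timeout = 20.0  # illustrative value

# Sentinel arguments resolve to the table's READ_ROWS defaults.
assert _get_timeouts(TABLE_DEFAULT.READ_ROWS, TABLE_DEFAULT.READ_ROWS, fake_table) == (600.0, 20.0)

# An explicit attempt timeout is capped at the operation timeout.
assert _get_timeouts(5.0, 30.0, fake_table) == (5.0, 5.0)

# attempt=None falls back to the resolved operation timeout.
assert _get_timeouts(5.0, None, fake_table) == (5.0, 5.0)

# Invalid values are rejected by _validate_timeouts.
try:
    _get_timeouts(0, None, fake_table)
except ValueError as exc:
    print(exc)  # operation_timeout must be greater than 0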