diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 7d1144553..778aecb74 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -25,6 +25,7 @@ from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data._helpers import backoff_generator # mutate_rows requests are limited to this number of mutations from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT @@ -35,6 +36,7 @@ ) from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._metrics import ActiveOperationMetric @dataclass @@ -65,6 +67,7 @@ def __init__( mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, + metrics: ActiveOperationMetric, retryable_exceptions: Sequence[type[Exception]] = (), ): """ @@ -75,6 +78,8 @@ def __init__( - operation_timeout: the timeout to use for the entire operation, in seconds. - attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. If not specified, the request will run until operation_timeout is reached. + - metrics: the metrics object to use for tracking the operation + - retryable_exceptions: a list of exceptions that should be retried """ # check that mutations are within limits total_mutations = sum(len(entry.mutations) for entry in mutation_entries) @@ -100,7 +105,7 @@ def __init__( # Entry level errors bt_exceptions._MutateRowsIncomplete, ) - sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + sleep_generator = backoff_generator(0.01, 2, 60) self._operation = retries.retry_target_async( self._run_attempt, self.is_retryable, @@ -115,6 +120,9 @@ def __init__( self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries] self.remaining_indices = list(range(len(self.mutations))) self.errors: dict[int, list[Exception]] = {} + # set up metrics + metrics.backoff_generator = sleep_generator + self._operation_metrics = metrics async def start(self): """ @@ -148,9 +156,13 @@ async def start(self): bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) ) if all_errors: - raise bt_exceptions.MutationsExceptionGroup( + combined_exc = bt_exceptions.MutationsExceptionGroup( all_errors, len(self.mutations) ) + self._operation_metrics.end_with_status(combined_exc) + raise combined_exc + else: + self._operation_metrics.end_with_success() async def _run_attempt(self): """ @@ -161,6 +173,8 @@ async def _run_attempt(self): retry after the attempt is complete - GoogleAPICallError: if the gapic rpc fails """ + # register attempt start + self._operation_metrics.start_attempt() request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] # track mutations in this request that have not been finalized yet active_request_indices = { @@ -177,34 +191,47 @@ async def _run_attempt(self): entries=request_entries, retry=None, ) - async for result_list in result_generator: - for result in result_list.entries: - # convert sub-request index to global index - orig_idx = active_request_indices[result.index] - entry_error = core_exceptions.from_grpc_status( - result.status.code, - result.status.message, - details=result.status.details, - ) - if result.status.code 
!= 0: - # mutation failed; update error list (and remaining_indices if retryable) - self._handle_entry_error(orig_idx, entry_error) - elif orig_idx in self.errors: - # mutation succeeded; remove from error list - del self.errors[orig_idx] - # remove processed entry from active list - del active_request_indices[result.index] + try: + async for result_list in result_generator: + for result in result_list.entries: + # convert sub-request index to global index + orig_idx = active_request_indices[result.index] + entry_error = core_exceptions.from_grpc_status( + result.status.code, + result.status.message, + details=result.status.details, + ) + if result.status.code != 0: + # mutation failed; update error list (and remaining_indices if retryable) + self._handle_entry_error(orig_idx, entry_error) + elif orig_idx in self.errors: + # mutation succeeded; remove from error list + del self.errors[orig_idx] + # remove processed entry from active list + del active_request_indices[result.index] + finally: + # send trailing metadata to metrics + result_generator.cancel() + metadata = ( + await result_generator.trailing_metadata() + + await result_generator.initial_metadata() + ) + self._operation_metrics.add_response_metadata(metadata) except Exception as exc: # add this exception to list for each mutation that wasn't # already handled, and update remaining_indices if mutation is retryable for idx in active_request_indices.values(): self._handle_entry_error(idx, exc) + # record attempt failure metric + self._operation_metrics.end_attempt_with_status(exc) # bubble up exception to be handled by retry wrapper raise # check if attempt succeeded, or needs to be retried if self.remaining_indices: # unfinished work; raise exception to trigger retry - raise bt_exceptions._MutateRowsIncomplete + last_exc = self.errors[self.remaining_indices[-1]][-1] + self._operation_metrics.end_attempt_with_status(last_exc) + raise bt_exceptions._MutateRowsIncomplete() def _handle_entry_error(self, idx: int, exc: Exception): """ diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 9e0fd78e1..065490d3c 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -18,10 +18,10 @@ from typing import ( TYPE_CHECKING, AsyncGenerator, - AsyncIterable, Awaitable, Sequence, ) +import time from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB @@ -34,13 +34,16 @@ from google.cloud.bigtable.data.exceptions import _RowSetComplete from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import backoff_generator + +from google.api_core.grpc_helpers_async import GrpcAsyncStream from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.api_core import retry as retries -from google.api_core.retry import exponential_sleep_generator if TYPE_CHECKING: from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._metrics import ActiveOperationMetric class _ResetRow(Exception): @@ -70,6 +73,7 @@ class _ReadRowsOperationAsync: "_metadata", "_last_yielded_row_key", "_remaining_count", + "_operation_metrics", ) def __init__( @@ -78,6 +82,7 @@ def __init__( table: "TableAsync", operation_timeout: float, attempt_timeout: float, + metrics: 
ActiveOperationMetric, retryable_exceptions: Sequence[type[Exception]] = (), ): self.attempt_timeout_gen = _attempt_timeout_generator( @@ -100,15 +105,21 @@ def __init__( ) self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None + self._operation_metrics = metrics def start_operation(self) -> AsyncGenerator[Row, None]: """ Start the read_rows operation, retrying on retryable errors. """ + self._operation_metrics.start() + + sleep_generator = backoff_generator() + self._operation_metrics.backoff_generator = sleep_generator + return retries.retry_target_stream_async( self._read_rows_attempt, - self._predicate, - exponential_sleep_generator(0.01, 60, multiplier=2), + self._operation_metrics.build_wrapped_predicate(self._predicate), + sleep_generator, self.operation_timeout, exception_factory=_retry_exception_factory, ) @@ -120,6 +131,8 @@ def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: which will call this function until it succeeds or a non-retryable error is raised. """ + # register metric start + self._operation_metrics.start_attempt() # revise request keys and ranges between attempts if self._last_yielded_row_key is not None: # if this is a retry, try to trim down the request to avoid ones we've already processed @@ -130,12 +143,12 @@ def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: ) except _RowSetComplete: # if we've already seen all the rows, we're done - return self.merge_rows(None) + return self.merge_rows(None, self._operation_metrics) # revise the limit based on number of rows already yielded if self._remaining_count is not None: self.request.rows_limit = self._remaining_count if self._remaining_count == 0: - return self.merge_rows(None) + return self.merge_rows(None, self._operation_metrics) # create and return a new row merger gapic_stream = self.table.client._gapic_client.read_rows( self.request, @@ -144,70 +157,82 @@ def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: retry=None, ) chunked_stream = self.chunk_stream(gapic_stream) - return self.merge_rows(chunked_stream) + return self.merge_rows(chunked_stream, self._operation_metrics) async def chunk_stream( - self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]] + self, stream: Awaitable[GrpcAsyncStream[ReadRowsResponsePB]] ) -> AsyncGenerator[ReadRowsResponsePB.CellChunk, None]: """ process chunks out of raw read_rows stream """ - async for resp in await stream: - # extract proto from proto-plus wrapper - resp = resp._pb + call = await stream + try: + async for resp in call: + # extract proto from proto-plus wrapper + resp = resp._pb + + # handle last_scanned_row_key packets, sent when server + # has scanned past the end of the row range + if resp.last_scanned_row_key: + if ( + self._last_yielded_row_key is not None + and resp.last_scanned_row_key <= self._last_yielded_row_key + ): + raise InvalidChunk("last scanned out of order") + self._last_yielded_row_key = resp.last_scanned_row_key - # handle last_scanned_row_key packets, sent when server - # has scanned past the end of the row range - if resp.last_scanned_row_key: - if ( - self._last_yielded_row_key is not None - and resp.last_scanned_row_key <= self._last_yielded_row_key - ): - raise InvalidChunk("last scanned out of order") - self._last_yielded_row_key = resp.last_scanned_row_key - - current_key = None - # process each chunk in the response - for c in resp.chunks: - if current_key is None: - current_key = c.row_key + current_key = None + # process each chunk in the response + 
for c in resp.chunks: if current_key is None: - raise InvalidChunk("first chunk is missing a row key") - elif ( - self._last_yielded_row_key - and current_key <= self._last_yielded_row_key - ): - raise InvalidChunk("row keys should be strictly increasing") + current_key = c.row_key + if current_key is None: + raise InvalidChunk("first chunk is missing a row key") + elif ( + self._last_yielded_row_key + and current_key <= self._last_yielded_row_key + ): + raise InvalidChunk("row keys should be strictly increasing") - yield c + yield c - if c.reset_row: - current_key = None - elif c.commit_row: - # update row state after each commit - self._last_yielded_row_key = current_key - if self._remaining_count is not None: - self._remaining_count -= 1 - if self._remaining_count < 0: - raise InvalidChunk("emit count exceeds row limit") - current_key = None + if c.reset_row: + current_key = None + elif c.commit_row: + # update row state after each commit + self._last_yielded_row_key = current_key + if self._remaining_count is not None: + self._remaining_count -= 1 + if self._remaining_count < 0: + raise InvalidChunk("emit count exceeds row limit") + current_key = None + finally: + # ensure stream is closed + call.cancel() + # send trailing metadata to metrics + metadata = await call.trailing_metadata() + await call.initial_metadata() + self._operation_metrics.add_response_metadata(metadata) @staticmethod async def merge_rows( - chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None + chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None, + operation: ActiveOperationMetric, ): """ Merge chunks into rows """ if chunks is None: + operation.end_with_success() return it = chunks.__aiter__() + is_first_row = True # For each row while True: try: c = await it.__anext__() except StopAsyncIteration: # stream complete + operation.end_with_success() return row_key = c.row_key @@ -284,7 +309,17 @@ async def merge_rows( Cell(value, row_key, family, qualifier, ts, list(labels)) ) if c.commit_row: + if is_first_row: + # record first row latency in metrics + is_first_row = False + operation.attempt_first_response() + block_time = time.monotonic() yield Row(row_key, cells) + # most metric operations use setters, but this one updates + # the value directly to avoid extra overhead + operation.active_attempt.application_blocking_time += ( # type: ignore + time.monotonic() - block_time + ) break c = await it.__anext__() except _ResetRow as e: diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 11a4ed7bd..404611619 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -70,6 +70,7 @@ from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import backoff_generator from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule @@ -79,6 +80,7 @@ from google.cloud.bigtable.data.row_filters import RowFilterChain from google.cloud.bigtable.data._metrics import BigtableClientSideMetricsController +from google.cloud.bigtable.data._metrics import OperationType if TYPE_CHECKING: @@ -533,7 +535,6 @@ def __init__( table_id=table_id, 
app_profile_id=app_profile_id, ) - self.default_read_rows_retryable_errors = ( default_read_rows_retryable_errors or () ) @@ -560,6 +561,7 @@ async def read_rows_stream( attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + **kwargs, ) -> AsyncIterable[Row]: """ Read a set of rows from the table, based on the specified query. @@ -584,7 +586,7 @@ async def read_rows_stream( - an asynchronous iterator that yields rows returned by the query Raises: - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + will be chained with a RetryExceptionGroup containing GoogleAPICallError exceptions from any retries that failed - GoogleAPIError: raised if the request encounters an unrecoverable error """ @@ -593,11 +595,20 @@ async def read_rows_stream( ) retryable_excs = _get_retryable_errors(retryable_errors, self) + # extract metric operation if passed down through kwargs + # used so that read_row can disable is_streaming flag + metric_operation = kwargs.pop("metric_operation", None) + if metric_operation is None: + metric_operation = self._metrics.create_operation( + OperationType.READ_ROWS, is_streaming=True + ) + row_merger = _ReadRowsOperationAsync( query, self, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, + metrics=metric_operation, retryable_exceptions=retryable_excs, ) return row_merger.start_operation() @@ -610,6 +621,7 @@ async def read_rows( attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + **kwargs, ) -> list[Row]: """ Read a set of rows from the table, based on the specified query. 
@@ -637,15 +649,16 @@ async def read_rows( - a list of Rows returned by the query Raises: - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + will be chained with a RetryExceptionGroup containing GoogleAPICallError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + - GoogleAPICallError: raised if the request encounters an unrecoverable error """ row_generator = await self.read_rows_stream( query, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, retryable_errors=retryable_errors, + **kwargs, ) return [row async for row in row_generator] @@ -681,17 +694,21 @@ async def read_row( - a Row object if the row exists, otherwise None Raises: - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + will be chained with a RetryExceptionGroup containing GoogleAPICallError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + - GoogleAPICallError: raised if the request encounters an unrecoverable error """ if row_key is None: raise ValueError("row_key must be string or bytes") + metric_operation = self._metrics.create_operation( + OperationType.READ_ROWS, is_streaming=False + ) query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) results = await self.read_rows( query, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, + metric_operation=metric_operation, retryable_errors=retryable_errors, ) if len(results) == 0: @@ -751,8 +768,8 @@ async def read_rows_sharded( for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT) ] # run batches and collect results - results_list = [] - error_dict = {} + results_list: list[Row] = [] + error_dict: dict[int, Exception] = {} shard_idx = 0 for batch in batched_queries: batch_operation_timeout = next(timeout_generator) @@ -816,9 +833,9 @@ async def row_exists( - a bool indicating whether the row exists Raises: - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + will be chained with a RetryExceptionGroup containing GoogleAPICallError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + - GoogleAPICallError: raised if the request encounters an unrecoverable error """ if row_key is None: raise ValueError("row_key must be string or bytes") @@ -827,10 +844,14 @@ async def row_exists( limit_filter = CellsRowLimitFilter(1) chain_filter = RowFilterChain(filters=[limit_filter, strip_filter]) query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter) + metric_operation = self._metrics.create_operation( + OperationType.READ_ROWS, is_streaming=False + ) results = await self.read_rows( query, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, + metric_operation=metric_operation, retryable_errors=retryable_errors, ) return len(results) > 0 @@ -869,9 +890,9 @@ async def sample_row_keys( - a set of RowKeySamples the delimit contiguous sections of the table Raises: - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + will be chained with a RetryExceptionGroup containing GoogleAPICallError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an 
unrecoverable error + - GoogleAPICallError: raised if the request encounters an unrecoverable error """ # prepare timeouts operation_timeout, attempt_timeout = _get_timeouts( @@ -884,28 +905,43 @@ async def sample_row_keys( retryable_excs = _get_retryable_errors(retryable_errors, self) predicate = retries.if_exception_type(*retryable_excs) - sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + sleep_generator = backoff_generator() # prepare request metadata = _make_metadata(self.table_name, self.app_profile_id) - async def execute_rpc(): - results = await self.client._gapic_client.sample_row_keys( - table_name=self.table_name, - app_profile_id=self.app_profile_id, - timeout=next(attempt_timeout_gen), - metadata=metadata, - retry=None, - ) - return [(s.row_key, s.offset_bytes) async for s in results] + # wrap rpc in retry and metric collection logic + async with self._metrics.create_operation( + OperationType.SAMPLE_ROW_KEYS, backoff_generator=sleep_generator + ) as operation: + + async def execute_rpc(): + stream = await self.client._gapic_client.sample_row_keys( + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=next(attempt_timeout_gen), + metadata=metadata, + retry=None, + ) + samples = [(s.row_key, s.offset_bytes) async for s in stream] + # send metadata to metric collector + call_metadata = ( + await stream.trailing_metadata() + await stream.initial_metadata() + ) + operation.add_response_metadata(call_metadata) + # return results + return samples - return await retries.retry_target_async( - execute_rpc, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, - ) + metric_wrapped = operation.wrap_attempt_fn( + execute_rpc, extract_call_metadata=False + ) + return await retries.retry_target_async( + metric_wrapped, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) def mutations_batcher( self, @@ -992,8 +1028,8 @@ async def mutate_row( Raises: - DeadlineExceeded: raised after operation timeout will be chained with a RetryExceptionGroup containing all - GoogleAPIError exceptions from any retries that failed - - GoogleAPIError: raised on non-idempotent operations that cannot be + GoogleAPICallError exceptions from any retries that failed + - GoogleAPICallError: raised on non-idempotent operations that cannot be safely retried. 
- ValueError if invalid arguments are provided """ @@ -1014,25 +1050,34 @@ async def mutate_row( # mutations should not be retried predicate = retries.if_exception_type() - sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - - target = partial( - self.client._gapic_client.mutate_row, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - mutations=[mutation._to_pb() for mutation in mutations_list], - table_name=self.table_name, - app_profile_id=self.app_profile_id, - timeout=attempt_timeout, - metadata=_make_metadata(self.table_name, self.app_profile_id), - retry=None, - ) - return await retries.retry_target_async( - target, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, - ) + sleep_generator = backoff_generator() + + # wrap rpc in retry and metric collection logic + async with self._metrics.create_operation( + OperationType.MUTATE_ROW, backoff_generator=sleep_generator + ) as operation: + metric_wrapped = operation.wrap_attempt_fn( + self.client._gapic_client.mutate_row + ) + target = partial( + metric_wrapped, + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + mutations=[mutation._to_pb() for mutation in mutations_list], + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=attempt_timeout, + metadata=_make_metadata(self.table_name, self.app_profile_id), + retry=None, + ) + return await retries.retry_target_async( + target, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) async def bulk_mutate_rows( self, @@ -1085,6 +1130,7 @@ async def bulk_mutate_rows( mutation_entries, operation_timeout, attempt_timeout, + self._metrics.create_operation(OperationType.BULK_MUTATE_ROWS), retryable_exceptions=retryable_excs, ) await operation.start() @@ -1128,7 +1174,7 @@ async def check_and_mutate_row( Returns: - bool indicating whether the predicate was true or false Raises: - - GoogleAPIError exceptions from grpc call + - GoogleAPICallError exceptions from grpc call """ operation_timeout, _ = _get_timeouts(operation_timeout, None, self) if true_case_mutations is not None and not isinstance( @@ -1142,18 +1188,25 @@ async def check_and_mutate_row( false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] metadata = _make_metadata(self.table_name, self.app_profile_id) - result = await self.client._gapic_client.check_and_mutate_row( - true_mutations=true_case_list, - false_mutations=false_case_list, - predicate_filter=predicate._to_pb() if predicate is not None else None, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) - return result.predicate_matched + + async with self._metrics.create_operation( + OperationType.CHECK_AND_MUTATE + ) as operation: + metric_wrapped = operation.wrap_attempt_fn( + self.client._gapic_client.check_and_mutate_row + ) + result = await metric_wrapped( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode() if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + return result.predicate_matched async def read_modify_write_row( self, @@ 
-1183,7 +1236,7 @@ async def read_modify_write_row( - Row: containing cell data that was modified as part of the operation Raises: - - GoogleAPIError exceptions from grpc call + - GoogleAPICallError exceptions from grpc call - ValueError if invalid arguments are provided """ operation_timeout, _ = _get_timeouts(operation_timeout, None, self) @@ -1194,17 +1247,25 @@ async def read_modify_write_row( if not rules: raise ValueError("rules must contain at least one item") metadata = _make_metadata(self.table_name, self.app_profile_id) - result = await self.client._gapic_client.read_modify_write_row( - rules=[rule._to_pb() for rule in rules], - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) - # construct Row from result - return Row._from_pb(result.row) + + async with self._metrics.create_operation( + OperationType.READ_MODIFY_WRITE + ) as operation: + metric_wrapped = operation.wrap_attempt_fn( + self.client._gapic_client.read_modify_write_row + ) + + result = await metric_wrapped( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode() if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + # construct Row from result + return Row._from_pb(result.row) async def close(self): """ diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 5d5dd535e..dbf3102a9 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -17,6 +17,7 @@ from typing import Any, Sequence, TYPE_CHECKING import asyncio import atexit +import time import warnings from collections import deque @@ -32,6 +33,8 @@ _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data._metrics import OperationType +from google.cloud.bigtable.data._metrics import ActiveOperationMetric if TYPE_CHECKING: from google.cloud.bigtable.data._async.client import TableAsync @@ -328,9 +331,18 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): """ # flush new entries in_process_requests: list[asyncio.Future[list[FailedMutationEntryError]]] = [] + metric = self._table._metrics.create_operation(OperationType.BULK_MUTATE_ROWS) + flow_start_time = time.monotonic() async for batch in self._flow_control.add_to_flow(new_entries): - batch_task = self._create_bg_task(self._execute_mutate_rows, batch) + # add time waiting on flow control to throttling metric + metric.flow_throttling_time = time.monotonic() - flow_start_time + batch_task = self._create_bg_task(self._execute_mutate_rows, batch, metric) in_process_requests.append(batch_task) + # start a new metric for next batch + metric = self._table._metrics.create_operation( + OperationType.BULK_MUTATE_ROWS + ) + flow_start_time = time.monotonic() # wait for all inflight requests to complete found_exceptions = await self._wait_for_batch_results(*in_process_requests) # update exception data to reflect any new errors @@ -338,7 +350,7 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): self._add_exceptions(found_exceptions) async def _execute_mutate_rows( - self, batch: list[RowMutationEntry] + self, batch: list[RowMutationEntry], metrics: ActiveOperationMetric ) -> 
list[FailedMutationEntryError]: """ Helper to execute mutation operation on a batch @@ -358,6 +370,7 @@ async def _execute_mutate_rows( batch, operation_timeout=self._operation_timeout, attempt_timeout=self._attempt_timeout, + metrics=metrics, retryable_exceptions=self._retryable_errors, ) await operation.start() @@ -487,10 +500,10 @@ async def _wait_for_batch_results( found_errors = [] for result in all_results: if isinstance(result, Exception): - # will receive direct Exception objects if request task fails + # will receive Exception objects if request task fails. Add to list found_errors.append(result) elif isinstance(result, BaseException): - # BaseException not expected from grpc calls. Raise immediately + # BaseException won't be encountered in normal execution. Raise immediately raise result elif result: # completed requests will return a list of FailedMutationEntryError diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index a0b13cbaf..0060501ba 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -23,6 +23,7 @@ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.api_core import exceptions as core_exceptions +from google.api_core.retry import exponential_sleep_generator from google.api_core.retry import RetryFailureReason from google.cloud.bigtable.data.exceptions import RetryExceptionGroup @@ -97,6 +98,25 @@ def _attempt_timeout_generator( yield max(0, min(per_request_timeout, deadline - time.monotonic())) +def backoff_generator(initial=0.01, multiplier=2, maximum=60): + """ + Build a generator for exponential backoff sleep times. + + This implementation builds on top of api_core.retries.exponential_sleep_generator, + adding the ability to retrieve previous values using the send(idx) method. This is + used by the Metrics class to track the sleep times used for each attempt. 
+    """
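+    # A minimal usage sketch of the send(idx) protocol described above (illustrative
+    # only, not part of the patch itself; sleep values are jittered by api_core's
+    # exponential_sleep_generator, so only the mechanism is shown):
+    #
+    #     gen = backoff_generator(0.01, 2, 60)
+    #     first = next(gen)       # draws a new backoff value and appends it to history
+    #     second = next(gen)      # draws and records the next value
+    #     replayed = gen.send(0)  # returns history[0] (the first value) without advancing
+    #     third = next(gen)       # a plain next() resumes drawing new backoff values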
+ """ + history = [] + subgenerator = exponential_sleep_generator(initial, multiplier, maximum) + while True: + next_backoff = next(subgenerator) + history.append(next_backoff) + sent_idx = yield next_backoff + while sent_idx is not None: + # requesting from history + sent_idx = yield history[sent_idx] + + def _retry_exception_factory( exc_list: list[Exception], reason: RetryFailureReason, @@ -117,7 +137,7 @@ def _retry_exception_factory( timeout_val_str = f"of {timeout_val:0.1f}s " if timeout_val is not None else "" # if failed due to timeout, raise deadline exceeded as primary exception source_exc: Exception = core_exceptions.DeadlineExceeded( - f"operation_timeout{timeout_val_str} exceeded" + f"operation_timeout {timeout_val_str}exceeded" ) elif exc_list: # otherwise, raise non-retryable error as primary exception diff --git a/google/cloud/bigtable_v2/services/bigtable/async_client.py b/google/cloud/bigtable_v2/services/bigtable/async_client.py index e497ff25b..3146ad4a9 100644 --- a/google/cloud/bigtable_v2/services/bigtable/async_client.py +++ b/google/cloud/bigtable_v2/services/bigtable/async_client.py @@ -37,6 +37,7 @@ from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 from google.api_core import retry as retries +from google.api_core.grpc_helpers_async import GrpcAsyncStream from google.auth import credentials as ga_credentials # type: ignore from google.oauth2 import service_account # type: ignore @@ -222,7 +223,7 @@ def read_rows( retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), - ) -> Awaitable[AsyncIterable[bigtable.ReadRowsResponse]]: + ) -> Awaitable[GrpcAsyncStream[bigtable.ReadRowsResponse]]: r"""Streams back the contents of all requested rows in key order, optionally applying the same Reader filter to each. Depending on their size, rows and cells may be @@ -316,7 +317,7 @@ def sample_row_keys( retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), - ) -> Awaitable[AsyncIterable[bigtable.SampleRowKeysResponse]]: + ) -> Awaitable[GrpcAsyncStream[bigtable.SampleRowKeysResponse]]: r"""Returns a sample of row keys in the table. The returned row keys will delimit contiguous sections of the table of approximately equal size, which can be used @@ -400,7 +401,7 @@ def sample_row_keys( # Done; return the response. return response - async def mutate_row( + def mutate_row( self, request: Optional[Union[bigtable.MutateRowRequest, dict]] = None, *, @@ -504,17 +505,14 @@ async def mutate_row( ), ) - # Send the request. - response = await rpc( + # Return the grpc call coroutine + return rpc( request, retry=retry, timeout=timeout, metadata=metadata, ) - # Done; return the response. - return response - def mutate_rows( self, request: Optional[Union[bigtable.MutateRowsRequest, dict]] = None, @@ -525,7 +523,7 @@ def mutate_rows( retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), - ) -> Awaitable[AsyncIterable[bigtable.MutateRowsResponse]]: + ) -> Awaitable[GrpcAsyncStream[bigtable.MutateRowsResponse]]: r"""Mutates multiple rows in a batch. Each individual row is mutated atomically as in MutateRow, but the entire batch is not executed atomically. @@ -624,7 +622,7 @@ def mutate_rows( # Done; return the response. 
return response - async def check_and_mutate_row( + def check_and_mutate_row( self, request: Optional[Union[bigtable.CheckAndMutateRowRequest, dict]] = None, *, @@ -766,18 +764,15 @@ async def check_and_mutate_row( ), ) - # Send the request. - response = await rpc( + # Return the grpc call coroutine. + return rpc( request, retry=retry, timeout=timeout, metadata=metadata, ) - # Done; return the response. - return response - - async def ping_and_warm( + def ping_and_warm( self, request: Optional[Union[bigtable.PingAndWarmRequest, dict]] = None, *, @@ -857,18 +852,15 @@ async def ping_and_warm( gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) - # Send the request. - response = await rpc( + # Return the grpc call coroutine. + return rpc( request, retry=retry, timeout=timeout, metadata=metadata, ) - # Done; return the response. - return response - - async def read_modify_write_row( + def read_modify_write_row( self, request: Optional[Union[bigtable.ReadModifyWriteRowRequest, dict]] = None, *, @@ -979,17 +971,14 @@ async def read_modify_write_row( ), ) - # Send the request. - response = await rpc( + # Return the grpc call coroutine. + return rpc( request, retry=retry, timeout=timeout, metadata=metadata, ) - # Done; return the response. - return response - def generate_initial_change_stream_partitions( self, request: Optional[ @@ -1002,7 +991,7 @@ def generate_initial_change_stream_partitions( timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), ) -> Awaitable[ - AsyncIterable[bigtable.GenerateInitialChangeStreamPartitionsResponse] + GrpcAsyncStream[bigtable.GenerateInitialChangeStreamPartitionsResponse] ]: r"""NOTE: This API is intended to be used by Apache Beam BigtableIO. Returns the current list of partitions that make up the table's @@ -1086,17 +1075,14 @@ def generate_initial_change_stream_partitions( ), ) - # Send the request. - response = rpc( + # Return the grpc call coroutine. + return rpc( request, retry=retry, timeout=timeout, metadata=metadata, ) - # Done; return the response. - return response - def read_change_stream( self, request: Optional[Union[bigtable.ReadChangeStreamRequest, dict]] = None, @@ -1106,7 +1092,7 @@ def read_change_stream( retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), - ) -> Awaitable[AsyncIterable[bigtable.ReadChangeStreamResponse]]: + ) -> Awaitable[GrpcAsyncStream[bigtable.ReadChangeStreamResponse]]: r"""NOTE: This API is intended to be used by Apache Beam BigtableIO. Reads changes from a table's change stream. 
Changes will reflect both user-initiated mutations and diff --git a/mypy.ini b/mypy.ini index 31cc24223..1c755808e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -26,3 +26,6 @@ ignore_missing_imports = True [mypy-pytest] ignore_missing_imports = True + +[mypy-google.api.*] +ignore_missing_imports = True diff --git a/tests/unit/data/_async/__init__.py b/tests/unit/data/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index e03028c45..ffee14087 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -18,6 +18,8 @@ from google.rpc import status_pb2 import google.api_core.exceptions as core_exceptions +from .test_client import mock_grpc_call + # try/except added for compatibility with python < 3.8 try: from unittest import mock @@ -48,6 +50,7 @@ def _make_one(self, *args, **kwargs): kwargs["table"] = kwargs.pop("table", AsyncMock()) kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) + kwargs["metrics"] = kwargs.pop("metrics", mock.Mock()) kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) return self._target_class()(*args, **kwargs) @@ -67,9 +70,17 @@ def _make_mock_gapic(self, mutation_list, error_dict=None): mock_fn = AsyncMock() if error_dict is None: error_dict = {} - mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( - mutation_list, error_dict - ) + responses = [ + MutateRowsResponse( + entries=[ + MutateRowsResponse.Entry( + index=idx, status=status_pb2.Status(code=error_dict.get(idx, 0)) + ) + ] + ) + for idx, _ in enumerate(mutation_list) + ] + mock_fn.return_value = mock_grpc_call(stream_response=responses) return mock_fn def test_ctor(self): @@ -86,6 +97,7 @@ def test_ctor(self): entries = [_make_mutation(), _make_mutation()] operation_timeout = 0.05 attempt_timeout = 0.01 + metrics = mock.Mock() retryable_exceptions = () instance = self._make_one( client, @@ -93,6 +105,7 @@ def test_ctor(self): entries, operation_timeout, attempt_timeout, + metrics, retryable_exceptions, ) # running gapic_fn should trigger a client call @@ -123,6 +136,7 @@ def test_ctor(self): assert instance.is_retryable(RuntimeError("")) is False assert instance.remaining_indices == list(range(len(entries))) assert instance.errors == {} + assert instance._operation_metrics == metrics def test_ctor_too_many_entries(self): """ @@ -139,8 +153,11 @@ def test_ctor_too_many_entries(self): entries = [_make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT operation_timeout = 0.05 attempt_timeout = 0.01 + metrics = mock.Mock() # no errors if at limit - self._make_one(client, table, entries, operation_timeout, attempt_timeout) + self._make_one( + client, table, entries, operation_timeout, attempt_timeout, metrics + ) # raise error after crossing with pytest.raises(ValueError) as e: self._make_one( @@ -149,6 +166,7 @@ def test_ctor_too_many_entries(self): entries + [_make_mutation()], operation_timeout, attempt_timeout, + metrics, ) assert "mutate_rows requests can contain at most 100000 mutations" in str( e.value @@ -169,7 +187,12 @@ async def test_mutate_rows_operation(self): f"{cls.__module__}.{cls.__name__}._run_attempt", AsyncMock() ) as attempt_mock: instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout + client, + table, + entries, + 
operation_timeout, + operation_timeout, + mock.Mock(), ) await instance.start() assert attempt_mock.call_count == 1 @@ -191,7 +214,12 @@ async def test_mutate_rows_attempt_exception(self, exc_type): found_exc = None try: instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout + client, + table, + entries, + operation_timeout, + operation_timeout, + mock.Mock(), ) await instance._run_attempt() except Exception as e: @@ -227,7 +255,12 @@ async def test_mutate_rows_exception(self, exc_type): found_exc = None try: instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout + client, + table, + entries, + operation_timeout, + operation_timeout, + mock.Mock(), ) await instance.start() except MutationsExceptionGroup as e: @@ -270,6 +303,7 @@ async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): entries, operation_timeout, operation_timeout, + mock.Mock(), retryable_exceptions=(exc_type,), ) await instance.start() @@ -294,17 +328,19 @@ async def test_mutate_rows_incomplete_ignored(self): AsyncMock(), ) as attempt_mock: attempt_mock.side_effect = _MutateRowsIncomplete("ignored") - found_exc = None - try: + with pytest.raises(MutationsExceptionGroup) as e: instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout + client, + table, + entries, + operation_timeout, + operation_timeout, + mock.Mock(), ) await instance.start() - except MutationsExceptionGroup as e: - found_exc = e assert attempt_mock.call_count > 0 - assert len(found_exc.exceptions) == 1 - assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) + assert len(e.value.exceptions) == 1 + assert isinstance(e.value.exceptions[0].__cause__, DeadlineExceeded) @pytest.mark.asyncio async def test_run_attempt_single_entry_success(self): diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 4e7797c6d..31d6161d7 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -14,6 +14,7 @@ import pytest from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +from .test_client import mock_grpc_call # try/except added for compatibility with python < 3.8 try: @@ -60,6 +61,7 @@ def test_ctor(self): expected_operation_timeout = 42 expected_request_timeout = 44 time_gen_mock = mock.Mock() + metrics = mock.Mock() with mock.patch( "google.cloud.bigtable.data._async._read_rows._attempt_timeout_generator", time_gen_mock, @@ -69,6 +71,7 @@ def test_ctor(self): table, operation_timeout=expected_operation_timeout, attempt_timeout=expected_request_timeout, + metrics=metrics, ) assert time_gen_mock.call_count == 1 time_gen_mock.assert_called_once_with( @@ -87,6 +90,7 @@ def test_ctor(self): assert instance.request.table_name == table.table_name assert instance.request.app_profile_id == table.app_profile_id assert instance.request.rows_limit == row_limit + assert instance._operation_metrics == metrics @pytest.mark.parametrize( "in_keys,last_key,expected", @@ -228,31 +232,27 @@ async def test_revise_limit(self, start_limit, emit_num, expected_limit): from google.cloud.bigtable.data import ReadRowsQuery from google.cloud.bigtable_v2.types import ReadRowsResponse - async def awaitable_stream(): - async def mock_stream(): - for i in range(emit_num): - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk( - row_key=str(i).encode(), - family_name="b", - qualifier=b"c", - value=b"d", - commit_row=True, - ) - ] 
- ) - - return mock_stream() - query = ReadRowsQuery(limit=start_limit) table = mock.Mock() table.table_name = "table_name" table.app_profile_id = "app_profile_id" - instance = self._make_one(query, table, 10, 10) + instance = self._make_one(query, table, 10, 10, mock.Mock()) assert instance._remaining_count == start_limit # read emit_num rows - async for val in instance.chunk_stream(awaitable_stream()): + chunks = [ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + for i in range(emit_num) + ] + stream = mock_grpc_call( + stream_response=[ReadRowsResponse(chunks=[c]) for c in chunks] + ) + async for val in instance.chunk_stream(stream): pass assert instance._remaining_count == expected_limit @@ -267,32 +267,28 @@ async def test_revise_limit_over_limit(self, start_limit, emit_num): from google.cloud.bigtable_v2.types import ReadRowsResponse from google.cloud.bigtable.data.exceptions import InvalidChunk - async def awaitable_stream(): - async def mock_stream(): - for i in range(emit_num): - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk( - row_key=str(i).encode(), - family_name="b", - qualifier=b"c", - value=b"d", - commit_row=True, - ) - ] - ) - - return mock_stream() - query = ReadRowsQuery(limit=start_limit) table = mock.Mock() table.table_name = "table_name" table.app_profile_id = "app_profile_id" - instance = self._make_one(query, table, 10, 10) + instance = self._make_one(query, table, 10, 10, mock.Mock()) assert instance._remaining_count == start_limit with pytest.raises(InvalidChunk) as e: # read emit_num rows - async for val in instance.chunk_stream(awaitable_stream()): + chunks = [ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + for i in range(emit_num) + ] + stream = mock_grpc_call( + stream_response=[ReadRowsResponse(chunks=[c]) for c in chunks] + ) + async for val in instance.chunk_stream(stream): pass assert "emit count exceeds row limit" in str(e.value) @@ -302,17 +298,12 @@ async def test_aclose(self): should be able to close a stream safely with aclose. 
Closed generators should raise StopAsyncIteration on next yield """ - - async def mock_stream(): - while True: - yield 1 - with mock.patch.object( _ReadRowsOperationAsync, "_read_rows_attempt" ) as mock_attempt: - instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) - wrapped_gen = mock_stream() - mock_attempt.return_value = wrapped_gen + instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1, mock.Mock()) + call = mock_grpc_call(stream_response=range(100)) + mock_attempt.return_value = call gen = instance.start_operation() # read one row await gen.__anext__() @@ -323,7 +314,7 @@ async def mock_stream(): await gen.aclose() # ensure close was propagated to wrapped generator with pytest.raises(StopAsyncIteration): - await wrapped_gen.__anext__() + await call.__anext__() @pytest.mark.asyncio async def test_retryable_ignore_repeated_rows(self): @@ -336,26 +327,14 @@ async def test_retryable_ignore_repeated_rows(self): row_key = b"duplicate" - async def mock_awaitable_stream(): - async def mock_stream(): - while True: - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) - ] - ) - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) - ] - ) - - return mock_stream() - instance = mock.Mock() instance._last_yielded_row_key = None instance._remaining_count = None - stream = _ReadRowsOperationAsync.chunk_stream(instance, mock_awaitable_stream()) + chunks = [ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True)] * 2 + grpc_call = mock_grpc_call( + stream_response=[ReadRowsResponse(chunks=[c]) for c in chunks] + ) + stream = _ReadRowsOperationAsync.chunk_stream(instance, grpc_call) await stream.__anext__() with pytest.raises(InvalidChunk) as exc: await stream.__anext__() diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 9a12abe9b..5815097da 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -45,6 +45,70 @@ ) +class mock_grpc_call: + """ + Used for mocking the responses from grpc calls. Can simulate both unary and streaming calls. 
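+
+    A minimal usage sketch (illustrative only, assuming an async test body; the canned
+    metadata defaults to empty grpc.aio.Metadata()):
+
+        call = mock_grpc_call(unary_response="resp")
+        result = await call                        # unary calls resolve to the canned response
+        stream = mock_grpc_call(stream_response=[1, 2, 3])
+        items = [item async for item in stream]    # streaming calls yield each canned item
+        trailers = await stream.trailing_metadata()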
+ """ + + def __init__( + self, + unary_response=None, + stream_response=(), + sleep_time=0, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(), + ): + self.unary_response = unary_response + self.stream_response = stream_response + self.sleep_time = sleep_time + self.stream_idx = -1 + self._future = asyncio.get_event_loop().create_future() + self._future.set_result(unary_response) + self._initial_metadata = initial_metadata + self._trailing_metadata = trailing_metadata + + def __await__(self): + response = yield from self._future.__await__() + if response is None: + # await is a no-op for streaming calls + return self + # otherwise return unary response + return response + + def __aiter__(self): + return self + + async def __anext__(self): + self.stream_idx += 1 + if self.stream_idx < len(self.stream_response): + await asyncio.sleep(self.sleep_time) + next_val = self.stream_response[self.stream_idx] + if isinstance(next_val, Exception): + raise next_val + return next_val + raise StopAsyncIteration + + def cancel(self): + pass + + async def asend(self, val): + """ + implement generator protocol, so retries will treat this as a generator + i.e, call aclose at end of stream + """ + return await self.__anext__() + + async def aclose(self): + # simulate closing streams by jumping to the end + self.stream_idx = len(self.stream_response) + + async def trailing_metadata(self): + return self._trailing_metadata + + async def initial_metadata(self): + return self._initial_metadata + + class TestBigtableDataClientAsync: def _get_target_class(self): from google.cloud.bigtable.data._async.client import BigtableDataClientAsync @@ -1004,6 +1068,10 @@ async def test_table_ctor(self): from google.cloud.bigtable.data._async.client import BigtableDataClientAsync from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._metrics import ( + BigtableClientSideMetricsController, + ) + from google.cloud.bigtable.data._metrics import OpenTelemetryMetricsHandler expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -1057,6 +1125,10 @@ async def test_table_ctor(self): table.default_mutate_rows_attempt_timeout == expected_mutate_rows_attempt_timeout ) + # check metrics object + assert isinstance(table._metrics, BigtableClientSideMetricsController) + assert len(table._metrics.handlers) == 1 + assert isinstance(table._metrics.handlers[0], OpenTelemetryMetricsHandler) # ensure task reaches completion await table._register_instance_task assert table._register_instance_task.done() @@ -1226,7 +1298,8 @@ async def test_customizable_retryable_errors( with mock.patch(retry_fn_path) as retry_fn_mock: async with BigtableDataClientAsync() as client: table = client.get_table("instance-id", "table-id") - expected_predicate = lambda a: a in expected_retryables # noqa + expected_predicate = mock.Mock() + expected_predicate.side_effect = lambda exc: exc in expected_retryables retry_fn_mock.side_effect = RuntimeError("stop early") with mock.patch( "google.api_core.retry.if_exception_type" @@ -1242,7 +1315,13 @@ async def test_customizable_retryable_errors( ) retry_call_args = retry_fn_mock.call_args_list[0].args # output of if_exception_type should be sent in to retry constructor - assert retry_call_args[1] is expected_predicate + # note: may be wrapped by metrics + assert expected_predicate.call_count == 0 + found_predicate = retry_call_args[1] + obj = RuntimeError("test") + 
found_predicate(obj) + assert expected_predicate.call_count == 1 + assert expected_predicate.called_with(obj) @pytest.mark.parametrize( "fn_name,fn_args,gapic_fn", @@ -1330,6 +1409,7 @@ def _make_table(self, *args, **kwargs): ) client_mock._gapic_client.table_path.return_value = kwargs["table_id"] client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] + client_mock.project = "test-project" return TableAsync(client_mock, *args, **kwargs) def _make_stats(self): @@ -1367,31 +1447,11 @@ async def _make_gapic_stream( ): from google.cloud.bigtable_v2 import ReadRowsResponse - class mock_stream: - def __init__(self, chunk_list, sleep_time): - self.chunk_list = chunk_list - self.idx = -1 - self.sleep_time = sleep_time - - def __aiter__(self): - return self - - async def __anext__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - if sleep_time: - await asyncio.sleep(self.sleep_time) - chunk = self.chunk_list[self.idx] - if isinstance(chunk, Exception): - raise chunk - else: - return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration - - def cancel(self): - pass - - return mock_stream(chunk_list, sleep_time) + pb_list = [ + c if isinstance(c, Exception) else ReadRowsResponse(chunks=[c]) + for c in chunk_list + ] + return mock_grpc_call(stream_response=pb_list, sleep_time=sleep_time) async def execute_fn(self, table, *args, **kwargs): return await table.read_rows(*args, **kwargs) @@ -1464,17 +1524,16 @@ async def test_read_rows_timeout(self, operation_timeout): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows query = ReadRowsQuery() - chunks = [self._make_chunk(row_key=b"test_1")] + chunks = [core_exceptions.DeadlineExceeded("test timeout")] * 5 read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=1 + chunks, sleep_time=0.05 ) - try: + with pytest.raises(core_exceptions.DeadlineExceeded) as e: await table.read_rows(query, operation_timeout=operation_timeout) - except core_exceptions.DeadlineExceeded as e: - assert ( - e.message - == f"operation_timeout of {operation_timeout:0.1f}s exceeded" - ) + assert ( + e.value.message + == f"operation_timeout of {operation_timeout:0.1f}s exceeded" + ) @pytest.mark.parametrize( "per_request_t, operation_t, expected_num", @@ -1510,22 +1569,21 @@ async def test_read_rows_attempt_timeout( query = ReadRowsQuery() chunks = [core_exceptions.DeadlineExceeded("mock deadline")] - try: + with pytest.raises(core_exceptions.DeadlineExceeded) as e: await table.read_rows( query, operation_timeout=operation_t, attempt_timeout=per_request_t, ) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - if expected_num == 0: - assert retry_exc is None - else: - assert type(retry_exc) is RetryExceptionGroup - assert f"{expected_num} failed attempts" in str(retry_exc) - assert len(retry_exc.exceptions) == expected_num - for sub_exc in retry_exc.exceptions: - assert sub_exc.message == "mock deadline" + retry_exc = e.value.__cause__ + if expected_num == 0: + assert retry_exc is None + else: + assert type(retry_exc) is RetryExceptionGroup + assert f"{expected_num} failed attempts" in str(retry_exc) + assert len(retry_exc.exceptions) == expected_num + for sub_exc in retry_exc.exceptions: + assert sub_exc.message == "mock deadline" assert read_rows.call_count == expected_num # check timeouts for _, call_kwargs in read_rows.call_args_list[:-1]: @@ -1557,13 +1615,12 @@ async def test_read_rows_retryable_error(self, exc_type): ) query = 
ReadRowsQuery() expected_error = exc_type("mock error") - try: + with pytest.raises(core_exceptions.DeadlineExceeded) as e: await table.read_rows(query, operation_timeout=0.1) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - root_cause = retry_exc.exceptions[0] - assert type(root_cause) is exc_type - assert root_cause == expected_error + retry_exc = e.value.__cause__ + root_cause = retry_exc.exceptions[0] + assert type(root_cause) is exc_type + assert root_cause == expected_error @pytest.mark.parametrize( "exc_type", @@ -1581,17 +1638,16 @@ async def test_read_rows_retryable_error(self, exc_type): ) @pytest.mark.asyncio async def test_read_rows_non_retryable_error(self, exc_type): - async with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [expected_error] - ) - query = ReadRowsQuery() - expected_error = exc_type("mock error") - try: - await table.read_rows(query, operation_timeout=0.1) - except exc_type as e: - assert e == expected_error + table = self._make_table() + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + with pytest.raises(exc_type) as e: + await table.read_rows(query, operation_timeout=0.1) + assert e.value == expected_error @pytest.mark.asyncio async def test_read_rows_revise_request(self): @@ -1618,15 +1674,14 @@ async def test_read_rows_revise_request(self): self._make_chunk(row_key=b"test_1"), core_exceptions.Aborted("mock retryable error"), ] - try: + with pytest.raises(InvalidChunk): await table.read_rows(query) - except InvalidChunk: - revise_rowset.assert_called() - first_call_kwargs = revise_rowset.call_args_list[0].kwargs - assert first_call_kwargs["row_set"] == query._to_pb(table).rows - assert first_call_kwargs["last_seen_row_key"] == b"test_1" - revised_call = read_rows.call_args_list[1].args[0] - assert revised_call.rows == return_val + revise_rowset.assert_called() + first_call_kwargs = revise_rowset.call_args_list[0].kwargs + assert first_call_kwargs["row_set"] == query._to_pb(table).rows + assert first_call_kwargs["last_seen_row_key"] == b"test_1" + revised_call = read_rows.call_args_list[1].args[0] + assert revised_call.rows == return_val @pytest.mark.asyncio async def test_read_rows_default_timeouts(self): @@ -1819,10 +1874,10 @@ def _make_client(self, *args, **kwargs): @pytest.mark.asyncio async def test_read_rows_sharded_empty_query(self): async with self._make_client() as client: - async with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as exc: - await table.read_rows_sharded([]) - assert "empty sharded_query" in str(exc.value) + table = client.get_table("instance", "table") + with pytest.raises(ValueError) as exc: + await table.read_rows_sharded([]) + assert "empty sharded_query" in str(exc.value) @pytest.mark.asyncio async def test_read_rows_sharded_multiple_queries(self): @@ -1979,11 +2034,14 @@ def _make_client(self, *args, **kwargs): return BigtableDataClientAsync(*args, **kwargs) - async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): + def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): from google.cloud.bigtable_v2.types import SampleRowKeysResponse - for value in sample_list: - yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) + pb_list = [ + 
SampleRowKeysResponse(row_key=s[0], offset_bytes=s[1]) for s in sample_list + ] + + return mock_grpc_call(stream_response=pb_list) @pytest.mark.asyncio async def test_sample_row_keys(self): @@ -2154,32 +2212,30 @@ async def test_mutate_row(self, mutation_arg): """Test mutations with no errors""" expected_attempt_timeout = 19 async with self._make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.return_value = None - await table.mutate_row( - "row_key", - mutation_arg, - attempt_timeout=expected_attempt_timeout, - ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0].kwargs - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["row_key"] == b"row_key" - formatted_mutations = ( - [mutation._to_pb() for mutation in mutation_arg] - if isinstance(mutation_arg, list) - else [mutation_arg._to_pb()] - ) - assert kwargs["mutations"] == formatted_mutations - assert kwargs["timeout"] == expected_attempt_timeout - # make sure gapic layer is not retrying - assert kwargs["retry"] is None + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_row") as mock_gapic: + mock_gapic.return_value = mock_grpc_call() + await table.mutate_row( + "row_key", + mutation_arg, + attempt_timeout=expected_attempt_timeout, + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0].kwargs + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["row_key"] == b"row_key" + formatted_mutations = ( + [mutation._to_pb() for mutation in mutation_arg] + if isinstance(mutation_arg, list) + else [mutation_arg._to_pb()] + ) + assert kwargs["mutations"] == formatted_mutations + assert kwargs["timeout"] == expected_attempt_timeout + # make sure gapic layer is not retrying + assert kwargs["retry"] is None @pytest.mark.parametrize( "retryable_exception", @@ -2194,20 +2250,16 @@ async def test_mutate_row_retryable_errors(self, retryable_exception): from google.cloud.bigtable.data.exceptions import RetryExceptionGroup async with self._make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(DeadlineExceeded) as e: - mutation = mutations.DeleteAllFromRow() - assert mutation.is_idempotent() is True - await table.mutate_row( - "row_key", mutation, operation_timeout=0.01 - ) - cause = e.value.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], retryable_exception) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_row") as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + mutation = mutations.DeleteAllFromRow() + assert mutation.is_idempotent() is True + await table.mutate_row("row_key", mutation, operation_timeout=0.01) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) @pytest.mark.parametrize( "retryable_exception", @@ -2224,19 +2276,13 @@ async def test_mutate_row_non_idempotent_retryable_errors( Non-idempotent mutations should not be retried """ async 
with self._make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(retryable_exception): - mutation = mutations.SetCell( - "family", b"qualifier", b"value", -1 - ) - assert mutation.is_idempotent() is False - await table.mutate_row( - "row_key", mutation, operation_timeout=0.2 - ) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_row") as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(retryable_exception): + mutation = mutations.SetCell("family", b"qualifier", b"value", -1) + assert mutation.is_idempotent() is False + await table.mutate_row("row_key", mutation, operation_timeout=0.2) @pytest.mark.parametrize( "non_retryable_exception", @@ -2252,46 +2298,18 @@ async def test_mutate_row_non_idempotent_retryable_errors( @pytest.mark.asyncio async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): async with self._make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = non_retryable_exception("mock") - with pytest.raises(non_retryable_exception): - mutation = mutations.SetCell( - "family", - b"qualifier", - b"value", - timestamp_micros=1234567890, - ) - assert mutation.is_idempotent() is True - await table.mutate_row( - "row_key", mutation, operation_timeout=0.2 - ) - - @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio - async def test_mutate_row_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - async with self._make_client() as client: - async with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "mutate_row", AsyncMock() - ) as read_rows: - await table.mutate_row("rk", mock.Mock()) - kwargs = read_rows.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_row") as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + mutation = mutations.SetCell( + "family", + b"qualifier", + b"value", + timestamp_micros=1234567890, + ) + assert mutation.is_idempotent() is True + await table.mutate_row("row_key", mutation, operation_timeout=0.2) @pytest.mark.parametrize("mutations", [[], None]) @pytest.mark.asyncio @@ -2328,10 +2346,7 @@ async def _mock_response(self, response_list): for i in range(len(response_list)) ] - async def generator(): - yield MutateRowsResponse(entries=entries) - - return generator() + return mock_grpc_call(stream_response=[MutateRowsResponse(entries=entries)]) @pytest.mark.asyncio @pytest.mark.asyncio @@ -2358,25 +2373,23 @@ async def test_bulk_mutate_rows(self, mutation_arg): """Test 
mutations with no errors""" expected_attempt_timeout = 19 async with self._make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.return_value = self._mock_response([None]) - bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) - await table.bulk_mutate_rows( - [bulk_mutation], - attempt_timeout=expected_attempt_timeout, - ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args[1] - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["entries"] == [bulk_mutation._to_pb()] - assert kwargs["timeout"] == expected_attempt_timeout - assert kwargs["retry"] is None + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.return_value = self._mock_response([None]) + bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) + await table.bulk_mutate_rows( + [bulk_mutation], + attempt_timeout=expected_attempt_timeout, + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"] == [bulk_mutation._to_pb()] + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None @pytest.mark.asyncio async def test_bulk_mutate_rows_multiple_entries(self): @@ -2469,24 +2482,22 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable( ) async with self._make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.DeleteAllFromRow() - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - await table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert "non-idempotent" not in str(failed_exception) - assert isinstance(failed_exception, FailedMutationEntryError) - cause = failed_exception.__cause__ - assert isinstance(cause, exception) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + await table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, exception) @pytest.mark.parametrize( "retryable_exception", @@ -2509,25 +2520,23 @@ async def test_bulk_mutate_idempotent_retryable_request_errors( ) async with self._make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - 
mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - await table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" not in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], retryable_exception) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + await table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -2547,26 +2556,22 @@ async def test_bulk_mutate_rows_non_idempotent_retryable_errors( ) async with self._make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [retryable_exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", -1 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is False - await table.bulk_mutate_rows([entry], operation_timeout=0.2) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, retryable_exception) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [retryable_exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell("family", b"qualifier", b"value", -1) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is False + await table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, retryable_exception) @pytest.mark.parametrize( "non_retryable_exception", @@ -2589,24 +2594,22 @@ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_excepti ) async with 
self._make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = non_retryable_exception("mock") - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - await table.bulk_mutate_rows([entry], operation_timeout=0.2) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" not in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, non_retryable_exception) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + await table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, non_retryable_exception) @pytest.mark.asyncio async def test_bulk_mutate_error_index(self): @@ -2703,8 +2706,8 @@ async def test_check_and_mutate(self, gapic_result): with mock.patch.object( client._gapic_client, "check_and_mutate_row" ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=gapic_result + mock_gapic.return_value = mock_grpc_call( + CheckAndMutateRowResponse(predicate_matched=gapic_result) ) row_key = b"row_key" predicate = None @@ -2775,8 +2778,8 @@ async def test_check_and_mutate_single_mutations(self): with mock.patch.object( client._gapic_client, "check_and_mutate_row" ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True + mock_gapic.return_value = mock_grpc_call( + CheckAndMutateRowResponse(predicate_matched=True) ) true_mutation = SetCell("family", b"qualifier", b"value") false_mutation = SetCell("family", b"qualifier", b"value") @@ -2803,8 +2806,8 @@ async def test_check_and_mutate_predicate_object(self): with mock.patch.object( client._gapic_client, "check_and_mutate_row" ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True + mock_gapic.return_value = mock_grpc_call( + CheckAndMutateRowResponse(predicate_matched=True) ) await table.check_and_mutate_row( b"row_key", @@ -2831,8 +2834,8 @@ async def test_check_and_mutate_mutations_parsing(self): with mock.patch.object( client._gapic_client, "check_and_mutate_row" ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True + mock_gapic.return_value = mock_grpc_call( + CheckAndMutateRowResponse(predicate_matched=True) ) await table.check_and_mutate_row( b"row_key", @@ -2885,16 +2888,20 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules """ Test that the gapic call is called with given rules """ + from google.cloud.bigtable_v2.types import 
ReadModifyWriteRowResponse + async with self._make_client() as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - await table.read_modify_write_row("key", call_rules) - assert mock_gapic.call_count == 1 - found_kwargs = mock_gapic.call_args_list[0][1] - assert found_kwargs["rules"] == expected_rules - assert found_kwargs["retry"] is None + table = client.get_table("instance", "table") + with mock.patch.object( + client._gapic_client, + "read_modify_write_row", + ) as mock_gapic: + mock_gapic.return_value = mock_grpc_call(ReadModifyWriteRowResponse()) + await table.read_modify_write_row("key", call_rules) + assert mock_gapic.call_count == 1 + found_kwargs = mock_gapic.call_args_list[0][1] + assert found_kwargs["rules"] == expected_rules + assert found_kwargs["retry"] is None @pytest.mark.parametrize("rules", [[], None]) @pytest.mark.asyncio @@ -2907,6 +2914,8 @@ async def test_read_modify_write_no_rules(self, rules): @pytest.mark.asyncio async def test_read_modify_write_call_defaults(self): + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + instance = "instance1" table_id = "table1" project = "project1" @@ -2916,6 +2925,9 @@ async def test_read_modify_write_call_defaults(self): with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: + mock_gapic.return_value = mock_grpc_call( + ReadModifyWriteRowResponse() + ) await table.read_modify_write_row(row_key, mock.Mock()) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0][1] @@ -2929,6 +2941,8 @@ async def test_read_modify_write_call_defaults(self): @pytest.mark.asyncio async def test_read_modify_write_call_overrides(self): + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + row_key = b"row_key1" expected_timeout = 12345 profile_id = "profile1" @@ -2939,6 +2953,9 @@ async def test_read_modify_write_call_overrides(self): with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: + mock_gapic.return_value = mock_grpc_call( + ReadModifyWriteRowResponse() + ) await table.read_modify_write_row( row_key, mock.Mock(), @@ -2952,12 +2969,17 @@ async def test_read_modify_write_call_overrides(self): @pytest.mark.asyncio async def test_read_modify_write_string_key(self): + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + row_key = "string_row_key1" async with self._make_client() as client: async with client.get_table("instance", "table_id") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: + mock_gapic.return_value = mock_grpc_call( + ReadModifyWriteRowResponse() + ) await table.read_modify_write_row(row_key, mock.Mock()) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0][1] @@ -2978,8 +3000,8 @@ async def test_read_modify_write_row_building(self): with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: + mock_gapic.return_value = mock_grpc_call(mock_response) with mock.patch.object(Row, "_from_pb") as constructor_mock: - mock_gapic.return_value = mock_response await table.read_modify_write_row("key", mock.Mock()) assert constructor_mock.call_count == 1 constructor_mock.assert_called_once_with(mock_response.row) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 446cd822e..e9d0b9c8e 100644 --- 
a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -903,7 +903,8 @@ async def test__execute_mutate_rows(self, mutate_rows): table.default_mutate_rows_retryable_errors = () async with self._make_one(table) as instance: batch = [_make_mutation()] - result = await instance._execute_mutate_rows(batch) + mock_metric = mock.Mock() + result = await instance._execute_mutate_rows(batch, mock_metric) assert start_operation.call_count == 1 args, kwargs = mutate_rows.call_args assert args[0] == table.client._gapic_client @@ -911,6 +912,7 @@ async def test__execute_mutate_rows(self, mutate_rows): assert args[2] == batch kwargs["operation_timeout"] == 17 kwargs["attempt_timeout"] == 13 + kwargs["metrics"] == mock_metric assert result == [] @pytest.mark.asyncio @@ -933,7 +935,7 @@ async def test__execute_mutate_rows_returns_errors(self, mutate_rows): table.default_mutate_rows_retryable_errors = () async with self._make_one(table) as instance: batch = [_make_mutation()] - result = await instance._execute_mutate_rows(batch) + result = await instance._execute_mutate_rows(batch, mock.Mock()) assert len(result) == 2 assert result[0] == err1 assert result[1] == err2 @@ -1058,7 +1060,7 @@ async def test_timeout_args_passed(self, mutate_rows): assert instance._operation_timeout == expected_operation_timeout assert instance._attempt_timeout == expected_attempt_timeout # make simulated gapic call - await instance._execute_mutate_rows([_make_mutation()]) + await instance._execute_mutate_rows([_make_mutation()], mock.Mock()) assert mutate_rows.call_count == 1 kwargs = mutate_rows.call_args[1] assert kwargs["operation_timeout"] == expected_operation_timeout @@ -1174,7 +1176,8 @@ async def test_customizable_retryable_errors( predicate_builder_mock.return_value = expected_predicate retry_fn_mock.side_effect = RuntimeError("stop early") mutation = _make_mutation(count=1, size=1) - await instance._execute_mutate_rows([mutation]) + predicate_builder_mock.reset_mock() + await instance._execute_mutate_rows([mutation], mock.Mock()) # passed in errors should be used to build the predicate predicate_builder_mock.assert_called_once_with( *expected_retryables, _MutateRowsIncomplete @@ -1182,3 +1185,37 @@ async def test_customizable_retryable_errors( retry_call_args = retry_fn_mock.call_args_list[0].args # output of if_exception_type should be sent in to retry constructor assert retry_call_args[1] is expected_predicate + + @pytest.mark.asyncio + @pytest.mark.parametrize("sleep_time,flow_size", [(0, 10), (0.1, 1), (0.01, 10)]) + async def test_flow_throttling_metric(self, sleep_time, flow_size): + """ + When there are delays due to waiting on flow control, + should be reflected in operation metric's flow_throttling_time + """ + import time + from google.cloud.bigtable.data._metrics import ( + BigtableClientSideMetricsController, + ) + from google.cloud.bigtable.data._metrics import ActiveOperationMetric + + # create mock call + async def mock_add_to_flow(): + time.sleep(sleep_time) + for _ in range(flow_size): + await asyncio.sleep(0) + yield mock.Mock() + + mock_instance = mock.Mock() + mock_instance._wait_for_batch_results.return_value = asyncio.sleep(0) + mock_instance._entries_processed_since_last_raise = 0 + mock_instance._table._metrics = BigtableClientSideMetricsController([]) + mock_instance._flow_control.add_to_flow.return_value = mock_add_to_flow() + await self._get_target_class()._flush_internal(mock_instance, []) + # get list of metrics + mock_bg_task = 
mock_instance._create_bg_task + metric_list = [arg[0][-1] for arg in mock_bg_task.call_args_list] + # make sure operations were set up as expected + assert len(metric_list) == flow_size + assert all([isinstance(m, ActiveOperationMetric) for m in metric_list]) + assert abs(metric_list[0].flow_throttling_time - sleep_time) < 0.002 diff --git a/tests/unit/data/_metrics/test_data_model.py b/tests/unit/data/_metrics/test_data_model.py index 49a714a31..b132365d7 100644 --- a/tests/unit/data/_metrics/test_data_model.py +++ b/tests/unit/data/_metrics/test_data_model.py @@ -782,6 +782,27 @@ async def test_wrap_attempt_fn_success(self): assert len(metric.completed_attempts) == 1 assert metric.completed_attempts[0].end_status == StatusCode.OK + @pytest.mark.asyncio + async def test_wrap_attempt_fn_success_extract_call_metadata(self): + """ + When extract_call_metadata is True, should call add_response_metadata + on operation with output of wrapped function + """ + from .._async.test_client import mock_grpc_call + + metric = self._make_one(object()) + async with metric as context: + mock_call = mock_grpc_call() + inner_fn = lambda *args, **kwargs: mock_call # noqa + wrapped_fn = context.wrap_attempt_fn(inner_fn, extract_call_metadata=True) + with mock.patch.object( + metric, "add_response_metadata" + ) as mock_add_metadata: + # make the wrapped call + result = await wrapped_fn() + assert result == mock_call + assert mock_add_metadata.call_count == 1 + @pytest.mark.asyncio async def test_wrap_attempt_fn_failed_extract_call_metadata(self): """ diff --git a/tests/unit/data/_metrics/test_rpcs_instrumented.py b/tests/unit/data/_metrics/test_rpcs_instrumented.py new file mode 100644 index 000000000..08958b8f5 --- /dev/null +++ b/tests/unit/data/_metrics/test_rpcs_instrumented.py @@ -0,0 +1,303 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This file tests each rpc method to ensure they support metrics properly +""" + +import pytest +import mock +import datetime +from grpc import StatusCode +from grpc.aio import Metadata + +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data import mutations +from google.cloud.bigtable.data._metrics import OperationType +from google.cloud.bigtable.data._metrics.data_model import BIGTABLE_METADATA_KEY +from google.cloud.bigtable.data._metrics.data_model import SERVER_TIMING_METADATA_KEY + +from .._async.test_client import mock_grpc_call + + +RPC_ARGS = "fn_name,fn_args,gapic_fn,is_unary,expected_type" +RETRYABLE_RPCS = [ + ( + "read_rows_stream", + (ReadRowsQuery(),), + "read_rows", + False, + OperationType.READ_ROWS, + ), + ("read_rows", (ReadRowsQuery(),), "read_rows", False, OperationType.READ_ROWS), + ("read_row", (b"row_key",), "read_rows", False, OperationType.READ_ROWS), + ( + "read_rows_sharded", + ([ReadRowsQuery()],), + "read_rows", + False, + OperationType.READ_ROWS, + ), + ("row_exists", (b"row_key",), "read_rows", False, OperationType.READ_ROWS), + ("sample_row_keys", (), "sample_row_keys", False, OperationType.SAMPLE_ROW_KEYS), + ( + "mutate_row", + (b"row_key", [mutations.DeleteAllFromRow()]), + "mutate_row", + False, + OperationType.MUTATE_ROW, + ), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + "mutate_rows", + False, + OperationType.BULK_MUTATE_ROWS, + ), +] +ALL_RPCS = RETRYABLE_RPCS + [ + ( + "check_and_mutate_row", + (b"row_key", None), + "check_and_mutate_row", + True, + OperationType.CHECK_AND_MUTATE, + ), + ( + "read_modify_write_row", + (b"row_key", mock.Mock()), + "read_modify_write_row", + True, + OperationType.READ_MODIFY_WRITE, + ), +] + + +@pytest.mark.parametrize(RPC_ARGS, ALL_RPCS) +@pytest.mark.asyncio +async def test_rpc_instrumented(fn_name, fn_args, gapic_fn, is_unary, expected_type): + """check that all requests attach proper metadata headers""" + from google.cloud.bigtable.data import TableAsync + from google.cloud.bigtable.data import BigtableDataClientAsync + + cluster_data = "my-cluster" + zone_data = "my-zone" + expected_gfe_latency = 123 + + with mock.patch( + f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}" + ) as gapic_mock: + if is_unary: + unary_response = mock.Mock() + unary_response.row.families = [] # patch for read_modify_write_row + else: + unary_response = None + # populate metadata fields + initial_metadata = Metadata( + (BIGTABLE_METADATA_KEY, f"{zone_data} {cluster_data}".encode("utf-8")) + ) + trailing_metadata = Metadata( + (SERVER_TIMING_METADATA_KEY, f"gfet4t7; dur={expected_gfe_latency*1000}") + ) + grpc_call = mock_grpc_call( + unary_response=unary_response, + initial_metadata=initial_metadata, + trailing_metadata=trailing_metadata, + ) + gapic_mock.return_value = grpc_call + async with BigtableDataClientAsync() as client: + table = TableAsync(client, "instance-id", "table-id") + # customize metrics handlers + mock_metric_handler = mock.Mock() + table._metrics.handlers = [mock_metric_handler] + test_fn = table.__getattribute__(fn_name) + maybe_stream = await test_fn(*fn_args) + # iterate stream if it exists + try: + [i async for i in maybe_stream] + except TypeError: + pass + # check for recorded metrics values + assert mock_metric_handler.on_operation_complete.call_count == 1 + found_operation = mock_metric_handler.on_operation_complete.call_args[0][0] + # make sure expected fields were set properly + assert 
found_operation.op_type == expected_type + now = datetime.datetime.now(datetime.timezone.utc) + assert found_operation.start_time - now < datetime.timedelta(seconds=1) + assert found_operation.duration < 0.1 + assert found_operation.duration > 0 + assert found_operation.final_status == StatusCode.OK + assert found_operation.cluster_id == cluster_data + assert found_operation.zone == zone_data + # is_streaming should only be true for read_rows, read_rows_stream, and read_rows_sharded + assert found_operation.is_streaming == ("read_rows" in fn_name) + # check attempts + assert len(found_operation.completed_attempts) == 1 + found_attempt = found_operation.completed_attempts[0] + assert found_attempt.end_status == StatusCode.OK + assert found_attempt.start_time - now < datetime.timedelta(seconds=1) + assert found_attempt.duration < 0.1 + assert found_attempt.duration > 0 + assert found_attempt.start_time >= found_operation.start_time + assert found_attempt.duration <= found_operation.duration + assert found_attempt.gfe_latency == expected_gfe_latency + # first response latency not populated, because no real read_rows chunks processed + assert found_attempt.first_response_latency is None + # no application blocking time or backoff time expected + assert found_attempt.application_blocking_time == 0 + assert found_attempt.backoff_before_attempt == 0 + # no throttling expected + assert found_attempt.grpc_throttling_time == 0 + assert found_operation.flow_throttling_time == 0 + + +@pytest.mark.parametrize(RPC_ARGS, RETRYABLE_RPCS) +@pytest.mark.asyncio +async def test_rpc_instrumented_multiple_attempts( + fn_name, fn_args, gapic_fn, is_unary, expected_type +): + """check that all requests attach proper metadata headers, with a retry""" + from google.cloud.bigtable.data import TableAsync + from google.cloud.bigtable.data import BigtableDataClientAsync + from google.api_core.exceptions import Aborted + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc.status_pb2 import Status + + with mock.patch( + f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}" + ) as gapic_mock: + if is_unary: + unary_response = mock.Mock() + unary_response.row.families = [] # patch for read_modify_write_row + else: + unary_response = None + grpc_call = mock_grpc_call(unary_response=unary_response) + if gapic_fn == "mutate_rows": + # patch response to send success + grpc_call.stream_response = [ + MutateRowsResponse( + entries=[MutateRowsResponse.Entry(index=0, status=Status(code=0))] + ) + ] + gapic_mock.side_effect = [Aborted("first attempt failed"), grpc_call] + async with BigtableDataClientAsync() as client: + table = TableAsync(client, "instance-id", "table-id") + # customize metrics handlers + mock_metric_handler = mock.Mock() + table._metrics.handlers = [mock_metric_handler] + test_fn = table.__getattribute__(fn_name) + maybe_stream = await test_fn(*fn_args, retryable_errors=(Aborted,)) + # iterate stream if it exists + try: + [_ async for _ in maybe_stream] + except TypeError: + pass + # check for recorded metrics values + assert mock_metric_handler.on_operation_complete.call_count == 1 + found_operation = mock_metric_handler.on_operation_complete.call_args[0][0] + # make sure expected fields were set properly + assert found_operation.op_type == expected_type + now = datetime.datetime.now(datetime.timezone.utc) + assert found_operation.start_time - now < datetime.timedelta(seconds=1) + assert found_operation.duration < 0.1 + assert found_operation.duration > 0 + assert 
found_operation.final_status == StatusCode.OK + # metadata wasn't set, should see default values + assert found_operation.cluster_id == "unspecified" + assert found_operation.zone == "global" + # is_streaming should only be true for read_rows, read_rows_stream, and read_rows_sharded + assert found_operation.is_streaming == ("read_rows" in fn_name) + # check attempts + assert len(found_operation.completed_attempts) == 2 + failure, success = found_operation.completed_attempts + for attempt in [success, failure]: + # check things that should be consistent across attempts + assert attempt.start_time - now < datetime.timedelta(seconds=1) + assert attempt.duration < 0.1 + assert attempt.duration > 0 + assert attempt.start_time >= found_operation.start_time + assert attempt.duration <= found_operation.duration + assert attempt.application_blocking_time == 0 + assert success.end_status == StatusCode.OK + assert failure.end_status == StatusCode.ABORTED + assert success.start_time > failure.start_time + datetime.timedelta( + seconds=failure.duration + ) + assert success.backoff_before_attempt > 0 + assert failure.backoff_before_attempt == 0 + + +@pytest.mark.asyncio +async def test_batcher_rpcs_instrumented(): + """check that all requests attach proper metadata headers""" + from google.cloud.bigtable.data import TableAsync + from google.cloud.bigtable.data import BigtableDataClientAsync + + cluster_data = "my-cluster" + zone_data = "my-zone" + expected_gfe_latency = 123 + + with mock.patch( + "google.cloud.bigtable_v2.BigtableAsyncClient.mutate_rows" + ) as gapic_mock: + # populate metadata fields + initial_metadata = Metadata( + (BIGTABLE_METADATA_KEY, f"{zone_data} {cluster_data}".encode("utf-8")) + ) + trailing_metadata = Metadata( + (SERVER_TIMING_METADATA_KEY, f"gfet4t7; dur={expected_gfe_latency*1000}") + ) + grpc_call = mock_grpc_call( + initial_metadata=initial_metadata, trailing_metadata=trailing_metadata + ) + gapic_mock.return_value = grpc_call + async with BigtableDataClientAsync() as client: + table = TableAsync(client, "instance-id", "table-id") + # customize metrics handlers + mock_metric_handler = mock.Mock() + table._metrics.handlers = [mock_metric_handler] + async with table.mutations_batcher() as batcher: + await batcher.append( + mutations.RowMutationEntry( + b"row-key", [mutations.DeleteAllFromRow()] + ) + ) + # check for recorded metrics values + assert mock_metric_handler.on_operation_complete.call_count == 1 + found_operation = mock_metric_handler.on_operation_complete.call_args[0][0] + # make sure expected fields were set properly + assert found_operation.op_type == OperationType.BULK_MUTATE_ROWS + now = datetime.datetime.now(datetime.timezone.utc) + assert found_operation.start_time - now < datetime.timedelta(seconds=1) + assert found_operation.duration < 0.1 + assert found_operation.duration > 0 + assert found_operation.final_status == StatusCode.OK + assert found_operation.cluster_id == cluster_data + assert found_operation.zone == zone_data + assert found_operation.is_streaming is False + # check attempts + assert len(found_operation.completed_attempts) == 1 + found_attempt = found_operation.completed_attempts[0] + assert found_attempt.end_status == StatusCode.OK + assert found_attempt.start_time - now < datetime.timedelta(seconds=1) + assert found_attempt.duration < 0.1 + assert found_attempt.duration > 0 + assert found_attempt.start_time >= found_operation.start_time + assert found_attempt.duration <= found_operation.duration + assert found_attempt.gfe_latency == 
expected_gfe_latency + # first response latency not populated, because no real read_rows chunks processed + assert found_attempt.first_response_latency is None + # no application blocking time or backoff time expected + assert found_attempt.application_blocking_time == 0 + assert found_attempt.backoff_before_attempt == 0 diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 5a9c500ed..82af1fb53 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -1,3 +1,4 @@ +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -99,6 +100,67 @@ def test_attempt_timeout_w_sleeps(self): expected_value -= sleep_time + +class TestBackoffGenerator: + """ + test backoff_generator wrapper. + Should wrap api_core.exponential_sleep_generator, with added history + """ + + def test_defaults(self): + """ + expect defaults: initial=0.01, multiplier=2, maximum=60 + """ + with mock.patch( + "google.cloud.bigtable.data._helpers.exponential_sleep_generator" + ) as mock_exponential_sleep_generator: + generator = _helpers.backoff_generator() + next(generator) + assert mock_exponential_sleep_generator.call_args[0] == (0.01, 2, 60) + + def test_wraps_exponential_sleep_generator(self): + """test that it wraps exponential_sleep_generator""" + args = (1, 2, 3) + with mock.patch( + "google.cloud.bigtable.data._helpers.exponential_sleep_generator" + ) as mock_exponential_sleep_generator: + expected_results = [1, 7, 9, "a", "b"] + mock_exponential_sleep_generator.return_value = iter(expected_results) + generator = _helpers.backoff_generator(*args) + for val in expected_results: + assert next(generator) == val + assert mock_exponential_sleep_generator.call_count == 1 + # args from backoff generator should be passed through + assert mock_exponential_sleep_generator.call_args == mock.call(*args) + + def test_send_gives_history(self): + """ + Calling send with an index should give back the value that was yielded at that index + """ + with mock.patch( + "google.cloud.bigtable.data._helpers.exponential_sleep_generator" + ) as mock_exponential_sleep_generator: + expected_results = [2, 4, 6, 8, 10] + mock_exponential_sleep_generator.return_value = iter(expected_results) + generator = _helpers.backoff_generator() + # calling next should yield values from the wrapped iterator + assert next(generator) == 2 + assert next(generator) == 4 + assert next(generator) == 6 + # calling send with an index should return the value at that index + assert generator.send(0) == expected_results[0] + assert generator.send(2) == expected_results[2] + assert generator.send(1) == expected_results[1] + assert generator.send(0) == expected_results[0] + assert generator.send(0) == expected_results[0] + # should be able to continue iterating as normal + assert next(generator) == 8 + assert generator.send(0) == expected_results[0] + assert next(generator) == 10 + # calling an index out of range should raise an error + with pytest.raises(IndexError): + generator.send(100) + + class TestValidateTimeouts: def test_validate_timeouts_error_messages(self): with pytest.raises(ValueError) as e: diff --git a/tests/unit/data/test_read_rows_acceptance.py b/tests/unit/data/test_read_rows_acceptance.py index 15680984b..475201c47 100644 --- a/tests/unit/data/test_read_rows_acceptance.py +++ b/tests/unit/data/test_read_rows_acceptance.py @@ -27,6 +27,7 @@ from google.cloud.bigtable.data.row import Row from
..v2_client.test_row_merger import ReadRowsTest, TestFile +from ._async.test_client import mock_grpc_call def parse_readrows_acceptance_tests(): @@ -60,19 +61,20 @@ def extract_results_from_row(row: Row): ) @pytest.mark.asyncio async def test_row_merger_scenario(test_case: ReadRowsTest): - async def _scenerio_stream(): - for chunk in test_case.chunks: - yield ReadRowsResponse(chunks=[chunk]) + from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric try: results = [] instance = mock.Mock() instance._last_yielded_row_key = None instance._remaining_count = None - chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_scenerio_stream()) + stream = mock_grpc_call( + stream_response=[ReadRowsResponse(chunks=test_case.chunks)] ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) + chunker = _ReadRowsOperationAsync.chunk_stream(instance, stream) + metric = ActiveOperationMetric(0) + metric.start_attempt() + merger = _ReadRowsOperationAsync.merge_rows(chunker, metric) async for row in merger: for cell in row: cell_result = ReadRowsTest.Result( @@ -95,36 +97,18 @@ async def _scenerio_stream(): ) @pytest.mark.asyncio async def test_read_rows_scenario(test_case: ReadRowsTest): - async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): - from google.cloud.bigtable_v2 import ReadRowsResponse - - class mock_stream: - def __init__(self, chunk_list): - self.chunk_list = chunk_list - self.idx = -1 - - def __aiter__(self): - return self - - async def __anext__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - chunk = self.chunk_list[self.idx] - return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration - - def cancel(self): - pass - - return mock_stream(chunk_list) - try: client = BigtableDataClientAsync() table = client.get_table("instance", "table") results = [] - with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: + with mock.patch.object( + table.client._gapic_client, "read_rows", mock.AsyncMock() + ) as read_rows: # run once, then return error on retry - read_rows.return_value = _make_gapic_stream(test_case.chunks) + stream = mock_grpc_call( + stream_response=[ReadRowsResponse(chunks=[c]) for c in test_case.chunks] + ) + read_rows.return_value = stream async for row in await table.read_rows_stream(query={}): for cell in row: cell_result = ReadRowsTest.Result( @@ -146,16 +130,14 @@ def cancel(self): @pytest.mark.asyncio async def test_out_of_order_rows(): - async def _row_stream(): - yield ReadRowsResponse(last_scanned_row_key=b"a") - instance = mock.Mock() instance._remaining_count = None instance._last_yielded_row_key = b"b" chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_row_stream()) + instance, + mock_grpc_call(stream_response=[ReadRowsResponse(last_scanned_row_key=b"a")]), ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) + merger = _ReadRowsOperationAsync.merge_rows(chunker, mock.Mock()) with pytest.raises(InvalidChunk): async for _ in merger: pass @@ -308,21 +290,14 @@ async def test_mid_cell_labels_change(): ) -async def _coro_wrapper(stream): - return stream - - async def _process_chunks(*chunks): - async def _row_stream(): - yield ReadRowsResponse(chunks=chunks) - instance = mock.Mock() instance._remaining_count = None instance._last_yielded_row_key = None chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_row_stream()) + instance, mock_grpc_call(stream_response=[ReadRowsResponse(chunks=chunks)]) ) - merger = 
_ReadRowsOperationAsync.merge_rows(chunker) + merger = _ReadRowsOperationAsync.merge_rows(chunker, mock.Mock()) results = [] async for row in merger: results.append(row)