From 7ea3c23d7fd06a64dc87b5bad93134f05efb847b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 19 Nov 2024 16:41:57 -0800 Subject: [PATCH 1/5] chore: add cross_sync annotations (#1000) --- .github/workflows/conformance.yaml | 4 +- .kokoro/conformance.sh | 3 +- google/cloud/bigtable/data/__init__.py | 16 +- .../bigtable/data/_async/_mutate_rows.py | 40 +- .../cloud/bigtable/data/_async/_read_rows.py | 46 +- google/cloud/bigtable/data/_async/client.py | 395 ++++-- .../bigtable/data/_async/mutations_batcher.py | 190 +-- google/cloud/bigtable/data/exceptions.py | 15 + .../bigtable/data/execute_query/__init__.py | 2 + .../_async/execute_query_iterator.py | 103 +- google/cloud/bigtable/data/mutations.py | 12 + noxfile.py | 19 +- test_proxy/README.md | 7 +- ...r_data.py => client_handler_data_async.py} | 29 +- test_proxy/handlers/client_handler_legacy.py | 4 +- test_proxy/noxfile.py | 80 -- test_proxy/run_tests.sh | 3 +- test_proxy/test_proxy.py | 24 +- tests/system/data/__init__.py | 3 + tests/system/data/setup_fixtures.py | 25 - tests/system/data/test_execute_query_async.py | 283 ---- tests/system/data/test_execute_query_utils.py | 295 ---- tests/system/data/test_system.py | 937 ------------- tests/system/data/test_system_async.py | 1016 ++++++++++++++ tests/unit/data/_async/test__mutate_rows.py | 110 +- tests/unit/data/_async/test__read_rows.py | 76 +- tests/unit/data/_async/test_client.py | 1195 ++++++++++++----- .../data/_async/test_mutations_batcher.py | 806 +++++------ .../data/_async/test_read_rows_acceptance.py | 355 +++++ .../data/execute_query/_async/_testing.py | 36 - .../_async/test_query_iterator.py | 267 ++-- tests/unit/data/test_read_rows_acceptance.py | 331 ----- 32 files changed, 3430 insertions(+), 3297 deletions(-) rename test_proxy/handlers/{client_handler_data.py => client_handler_data_async.py} (90%) delete mode 100644 test_proxy/noxfile.py delete mode 100644 tests/system/data/test_execute_query_async.py delete mode 100644 tests/system/data/test_execute_query_utils.py delete mode 100644 tests/system/data/test_system.py create mode 100644 tests/system/data/test_system_async.py create mode 100644 tests/unit/data/_async/test_read_rows_acceptance.py delete mode 100644 tests/unit/data/execute_query/_async/_testing.py delete mode 100644 tests/unit/data/test_read_rows_acceptance.py diff --git a/.github/workflows/conformance.yaml b/.github/workflows/conformance.yaml index 68545cbec..448e1cc3a 100644 --- a/.github/workflows/conformance.yaml +++ b/.github/workflows/conformance.yaml @@ -26,9 +26,9 @@ jobs: matrix: test-version: [ "v0.0.2" ] py-version: [ 3.8 ] - client-type: [ "Async v3", "Legacy" ] + client-type: [ "async", "legacy" ] fail-fast: false - name: "${{ matrix.client-type }} Client / Python ${{ matrix.py-version }} / Test Tag ${{ matrix.test-version }}" + name: "${{ matrix.client-type }} client / python ${{ matrix.py-version }} / test tag ${{ matrix.test-version }}" steps: - uses: actions/checkout@v4 name: "Checkout python-bigtable" diff --git a/.kokoro/conformance.sh b/.kokoro/conformance.sh index 1c0b3ee0d..e85fc1394 100644 --- a/.kokoro/conformance.sh +++ b/.kokoro/conformance.sh @@ -23,7 +23,6 @@ PROXY_ARGS="" TEST_ARGS="" if [[ "${CLIENT_TYPE^^}" == "LEGACY" ]]; then echo "Using legacy client" - PROXY_ARGS="--legacy-client" # legacy client does not expose mutate_row. 
Disable those tests TEST_ARGS="-skip TestMutateRow_" fi @@ -31,7 +30,7 @@ fi # Build and start the proxy in a separate process PROXY_PORT=9999 pushd test_proxy -nohup python test_proxy.py --port $PROXY_PORT $PROXY_ARGS & +nohup python test_proxy.py --port $PROXY_PORT --client_type=$CLIENT_TYPE & proxyPID=$! popd diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 68dc22891..43ea69fdf 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -45,16 +45,30 @@ from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery +# setup custom CrossSync mappings for library +from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, +) +from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync + +from google.cloud.bigtable.data._cross_sync import CrossSync + +CrossSync.add_mapping("GapicClient", BigtableAsyncClient) +CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) +CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) +CrossSync.add_mapping("MutationsBatcher", MutationsBatcherAsync) + __version__: str = package_version.__version__ __all__ = ( "BigtableDataClientAsync", "TableAsync", + "MutationsBatcherAsync", "RowKeySamples", "ReadRowsQuery", "RowRange", - "MutationsBatcherAsync", "Mutation", "RowMutationEntry", "SetCell", diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 914cfecf4..c5795c464 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -15,37 +15,38 @@ from __future__ import annotations from typing import Sequence, TYPE_CHECKING -from dataclasses import dataclass import functools from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries -import google.cloud.bigtable_v2.types.bigtable as types_pb import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _retry_exception_factory # mutate_rows requests are limited to this number of mutations from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data.mutations import _EntryWithProto + +from google.cloud.bigtable.data._cross_sync import CrossSync if TYPE_CHECKING: - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) from google.cloud.bigtable.data.mutations import RowMutationEntry - from google.cloud.bigtable.data._async.client import TableAsync - -@dataclass -class _EntryWithProto: - """ - A dataclass to hold a RowMutationEntry and its corresponding proto representation. 
- """ + if CrossSync.is_async: + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient as GapicClientType, + ) + from google.cloud.bigtable.data._async.client import TableAsync as TableType + else: + from google.cloud.bigtable_v2.services.bigtable.client import ( # type: ignore + BigtableClient as GapicClientType, + ) + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType # type: ignore - entry: RowMutationEntry - proto: types_pb.MutateRowsRequest.Entry +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._mutate_rows" +@CrossSync.convert_class("_MutateRowsOperation") class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, @@ -65,10 +66,11 @@ class _MutateRowsOperationAsync: If not specified, the request will run until operation_timeout is reached. """ + @CrossSync.convert def __init__( self, - gapic_client: "BigtableAsyncClient", - table: "TableAsync", + gapic_client: GapicClientType, + table: TableType, mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, @@ -97,7 +99,7 @@ def __init__( bt_exceptions._MutateRowsIncomplete, ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - self._operation = retries.retry_target_async( + self._operation = lambda: CrossSync.retry_target( self._run_attempt, self.is_retryable, sleep_generator, @@ -112,6 +114,7 @@ def __init__( self.remaining_indices = list(range(len(self.mutations))) self.errors: dict[int, list[Exception]] = {} + @CrossSync.convert async def start(self): """ Start the operation, and run until completion @@ -121,7 +124,7 @@ async def start(self): """ try: # trigger mutate_rows - await self._operation + await self._operation() except Exception as exc: # exceptions raised by retryable are added to the list of exceptions for all unfinalized mutations incomplete_indices = self.remaining_indices.copy() @@ -148,6 +151,7 @@ async def start(self): all_errors, len(self.mutations) ) + @CrossSync.convert async def _run_attempt(self): """ Run a single attempt of the mutate_rows rpc. 
diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 5617e6418..c02b3750d 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -15,13 +15,7 @@ from __future__ import annotations -from typing import ( - TYPE_CHECKING, - AsyncGenerator, - AsyncIterable, - Awaitable, - Sequence, -) +from typing import Sequence, TYPE_CHECKING from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB @@ -32,21 +26,25 @@ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _RowSetComplete +from google.cloud.bigtable.data.exceptions import _ResetRow from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator -if TYPE_CHECKING: - from google.cloud.bigtable.data._async.client import TableAsync +from google.cloud.bigtable.data._cross_sync import CrossSync +if TYPE_CHECKING: + if CrossSync.is_async: + from google.cloud.bigtable.data._async.client import TableAsync as TableType + else: + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType # type: ignore -class _ResetRow(Exception): - def __init__(self, chunk): - self.chunk = chunk +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._read_rows" +@CrossSync.convert_class("_ReadRowsOperation") class _ReadRowsOperationAsync: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream @@ -80,7 +78,7 @@ class _ReadRowsOperationAsync: def __init__( self, query: ReadRowsQuery, - table: "TableAsync", + table: TableType, operation_timeout: float, attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), @@ -102,14 +100,14 @@ def __init__( self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None - def start_operation(self) -> AsyncGenerator[Row, None]: + def start_operation(self) -> CrossSync.Iterable[Row]: """ Start the read_rows operation, retrying on retryable errors. Yields: Row: The next row in the stream """ - return retries.retry_target_stream_async( + return CrossSync.retry_target_stream( self._read_rows_attempt, self._predicate, exponential_sleep_generator(0.01, 60, multiplier=2), @@ -117,7 +115,7 @@ def start_operation(self) -> AsyncGenerator[Row, None]: exception_factory=_retry_exception_factory, ) - def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: + def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: """ Attempt a single read_rows rpc call. 
This function is intended to be wrapped by retry logic, @@ -152,9 +150,10 @@ def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: chunked_stream = self.chunk_stream(gapic_stream) return self.merge_rows(chunked_stream) + @CrossSync.convert() async def chunk_stream( - self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]] - ) -> AsyncGenerator[ReadRowsResponsePB.CellChunk, None]: + self, stream: CrossSync.Awaitable[CrossSync.Iterable[ReadRowsResponsePB]] + ) -> CrossSync.Iterable[ReadRowsResponsePB.CellChunk]: """ process chunks out of raw read_rows stream @@ -204,9 +203,12 @@ async def chunk_stream( current_key = None @staticmethod + @CrossSync.convert( + replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"}, + ) async def merge_rows( - chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None - ) -> AsyncGenerator[Row, None]: + chunks: CrossSync.Iterable[ReadRowsResponsePB.CellChunk] | None, + ) -> CrossSync.Iterable[Row]: """ Merge chunks into rows @@ -222,7 +224,7 @@ async def merge_rows( while True: try: c = await it.__anext__() - except StopAsyncIteration: + except CrossSync.StopIteration: # stream complete return row_key = c.row_key @@ -315,7 +317,7 @@ async def merge_rows( ): raise InvalidChunk("reset row with data") continue - except StopAsyncIteration: + except CrossSync.StopIteration: raise InvalidChunk("premature end of stream") @staticmethod diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index f1f7ad1a3..d560d7e1e 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -15,88 +15,113 @@ from __future__ import annotations -import asyncio -from functools import partial -import os -import random -import sys -import time from typing import ( - TYPE_CHECKING, + cast, Any, AsyncIterable, - Dict, Optional, - Sequence, Set, - Union, - cast, + Sequence, + TYPE_CHECKING, ) + +import time import warnings +import random +import os +import concurrent.futures -from google.api_core import client_options as client_options_lib -from google.api_core import retry as retries -from google.api_core.exceptions import Aborted, DeadlineExceeded, ServiceUnavailable -import google.auth._default -import google.auth.credentials -from google.cloud.client import ClientWithProject -from google.cloud.environment_vars import BIGTABLE_EMULATOR # type: ignore -import grpc +from functools import partial +from grpc import Channel -from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT -from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( - ExecuteQueryIteratorAsync, -) -from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync -from google.cloud.bigtable.data._async.mutations_batcher import ( - _MB_SIZE, - MutationsBatcherAsync, -) -from google.cloud.bigtable.data._helpers import ( - _CONCURRENCY_LIMIT, - TABLE_DEFAULT, - _attempt_timeout_generator, - _get_error_type, - _get_retryable_errors, - _get_timeouts, - _retry_exception_factory, - _validate_timeouts, - _WarmedInstanceKey, -) -from google.cloud.bigtable.data.exceptions import ( - FailedQueryShardError, - ShardedReadRowsExceptionGroup, -) -from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry -from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule -from google.cloud.bigtable.data.read_rows_query import 
ReadRowsQuery -from google.cloud.bigtable.data.row import Row -from google.cloud.bigtable.data.row_filters import ( - CellsRowLimitFilter, - RowFilter, - RowFilterChain, - StripValueTransformerFilter, -) from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType from google.cloud.bigtable.data.execute_query.metadata import SqlType from google.cloud.bigtable.data.execute_query._parameters_formatting import ( _format_execute_query_params, ) -from google.cloud.bigtable_v2.services.bigtable.async_client import ( +from google.cloud.bigtable_v2.services.bigtable.transports.base import ( DEFAULT_CLIENT_INFO, - BigtableAsyncClient, -) -from google.cloud.bigtable_v2.services.bigtable.transports import ( - BigtableGrpcAsyncIOTransport, ) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest +from google.cloud.client import ClientWithProject +from google.cloud.environment_vars import BIGTABLE_EMULATOR # type: ignore +from google.api_core import retry as retries +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import ServiceUnavailable +from google.api_core.exceptions import Aborted + +import google.auth.credentials +import google.auth._default +from google.api_core import client_options as client_options_lib +from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT +from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.exceptions import FailedQueryShardError +from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _WarmedInstanceKey +from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT +from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data._helpers import _validate_timeouts +from google.cloud.bigtable.data._helpers import _get_error_type +from google.cloud.bigtable.data._helpers import _get_retryable_errors +from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry + +from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule +from google.cloud.bigtable.data.row_filters import RowFilter +from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter +from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter +from google.cloud.bigtable.data.row_filters import RowFilterChain + +from google.cloud.bigtable.data._cross_sync import CrossSync + +if CrossSync.is_async: + from grpc.aio import insecure_channel + from google.cloud.bigtable_v2.services.bigtable.transports import ( + BigtableGrpcAsyncIOTransport as TransportType, + ) + from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE +else: + from grpc import insecure_channel + from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore + if TYPE_CHECKING: - from google.cloud.bigtable.data._helpers import RowKeySamples, ShardedQuery + from google.cloud.bigtable.data._helpers import RowKeySamples + from google.cloud.bigtable.data._helpers import ShardedQuery + + if CrossSync.is_async: + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + 
from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( + ExecuteQueryIteratorAsync, + ) +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.client" + + +@CrossSync.convert_class( + sync_name="BigtableDataClient", + add_mapping_for_name="DataClient", +) class BigtableDataClientAsync(ClientWithProject): + @CrossSync.convert( + docstring_format_vars={ + "LOOP_MESSAGE": ( + "Client should be created within an async context (running event loop)", + None, + ), + "RAISE_NO_LOOP": ( + "RuntimeError: if called outside of an async context (no running event loop)", + None, + ), + } + ) def __init__( self, *, @@ -110,7 +135,7 @@ def __init__( """ Create a client instance for the Bigtable Data API - Client should be created within an async context (running event loop) + {LOOP_MESSAGE} Args: project: the project which the client acts on behalf of. @@ -125,7 +150,7 @@ def __init__( Client options used to set user options on the client. API Endpoint should be set through client_options. Raises: - RuntimeError: if called outside of an async context (no running event loop) + {RAISE_NO_LOOP} """ if "pool_size" in kwargs: warnings.warn("pool_size no longer supported") @@ -147,7 +172,7 @@ def __init__( stacklevel=2, ) # use insecure channel if emulator is set - custom_channel = grpc.aio.insecure_channel(self._emulator_host) + custom_channel = insecure_channel(self._emulator_host) if credentials is None: credentials = google.auth.credentials.AnonymousCredentials() if project is None: @@ -159,24 +184,26 @@ def __init__( project=project, client_options=client_options, ) - self._gapic_client = BigtableAsyncClient( + self._gapic_client = CrossSync.GapicClient( credentials=credentials, client_options=client_options, client_info=client_info, - transport=lambda *args, **kwargs: BigtableGrpcAsyncIOTransport( + transport=lambda *args, **kwargs: TransportType( *args, **kwargs, channel=custom_channel ), ) - self.transport = cast( - BigtableGrpcAsyncIOTransport, self._gapic_client.transport - ) + self._is_closed = CrossSync.Event() + self.transport = cast(TransportType, self._gapic_client.transport) # keep track of active instances to for warmup on channel refresh self._active_instances: Set[_WarmedInstanceKey] = set() # keep track of table objects associated with each instance # only remove instance from _active_instances when all associated tables remove it self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() - self._channel_refresh_task: asyncio.Task[None] | None = None + self._channel_refresh_task: CrossSync.Task[None] | None = None + self._executor = ( + concurrent.futures.ThreadPoolExecutor() if not CrossSync.is_async else None + ) if self._emulator_host is None: # attempt to start background channel refresh tasks try: @@ -194,42 +221,58 @@ def _client_version() -> str: """ Helper function to return the client version string for this client """ - return f"{google.cloud.bigtable.__version__}-data-async" - + version_str = f"{google.cloud.bigtable.__version__}-data" + if CrossSync.is_async: + version_str += "-async" + return version_str + + @CrossSync.convert( + docstring_format_vars={ + "RAISE_NO_LOOP": ( + "RuntimeError: if not called in an asyncio event loop", + "None", + ) + } + ) def _start_background_channel_refresh(self) -> None: """ Starts a background task to ping and warm grpc channel Raises: - RuntimeError: if not called in an asyncio event loop + {RAISE_NO_LOOP} """ - if not self._channel_refresh_task and not 
self._emulator_host: - # raise RuntimeError if there is no event loop - asyncio.get_running_loop() - self._channel_refresh_task = asyncio.create_task(self._manage_channel()) - if sys.version_info >= (3, 8): - # task names supported in Python 3.8+ - self._channel_refresh_task.set_name( - f"{self.__class__.__name__} channel refresh" - ) + if ( + not self._channel_refresh_task + and not self._emulator_host + and not self._is_closed.is_set() + ): + # raise error if not in an event loop in async client + CrossSync.verify_async_event_loop() + self._channel_refresh_task = CrossSync.create_task( + self._manage_channel, + sync_executor=self._executor, + task_name=f"{self.__class__.__name__} channel refresh", + ) - async def close(self, timeout: float = 2.0): + @CrossSync.convert + async def close(self, timeout: float | None = 2.0): """ Cancel all background tasks """ - if self._channel_refresh_task: + self._is_closed.set() + if self._channel_refresh_task is not None: self._channel_refresh_task.cancel() - try: - await asyncio.wait_for(self._channel_refresh_task, timeout=timeout) - except asyncio.CancelledError: - pass + await CrossSync.wait([self._channel_refresh_task], timeout=timeout) await self.transport.close() + if self._executor: + self._executor.shutdown(wait=False) self._channel_refresh_task = None + @CrossSync.convert async def _ping_and_warm_instances( self, instance_key: _WarmedInstanceKey | None = None, - channel: grpc.aio.Channel | None = None, + channel: Channel | None = None, ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel @@ -251,23 +294,26 @@ async def _ping_and_warm_instances( request_serializer=PingAndWarmRequest.serialize, ) # prepare list of coroutines to run - tasks = [] - for instance_name, table_name, app_profile_id in instance_list: - metadata_str = f"name={instance_name}" - if app_profile_id is not None: - metadata_str = f"{metadata_str}&app_profile_id={app_profile_id}" - tasks.append( - ping_rpc( - request={"name": instance_name, "app_profile_id": app_profile_id}, - metadata=[("x-goog-request-params", metadata_str)], - wait_for_ready=True, - ) + partial_list = [ + partial( + ping_rpc, + request={"name": instance_name, "app_profile_id": app_profile_id}, + metadata=[ + ( + "x-goog-request-params", + f"name={instance_name}&app_profile_id={app_profile_id}", + ) + ], + wait_for_ready=True, ) - # execute coroutines in parallel - result_list = await asyncio.gather(*tasks, return_exceptions=True) - # return None in place of empty successful responses + for (instance_name, table_name, app_profile_id) in instance_list + ] + result_list = await CrossSync.gather_partials( + partial_list, return_exceptions=True, sync_executor=self._executor + ) return [r or None for r in result_list] + @CrossSync.convert async def _manage_channel( self, refresh_interval_min: float = 60 * 35, @@ -275,7 +321,7 @@ async def _manage_channel( grace_period: float = 60 * 10, ) -> None: """ - Background coroutine that periodically refreshes and warms a grpc channel + Background task that periodically refreshes and warms a grpc channel The backend will automatically close channels after 60 minutes, so `refresh_interval` + `grace_period` should be < 60 minutes @@ -300,22 +346,41 @@ async def _manage_channel( # warm the current channel immediately await self._ping_and_warm_instances(channel=self.transport.grpc_channel) # continuously refresh the channel every `refresh_interval` seconds - while True: - await asyncio.sleep(next_sleep) - start_timestamp = time.time() + while not 
self._is_closed.is_set(): + await CrossSync.event_wait( + self._is_closed, + next_sleep, + async_break_early=False, # no need to interrupt sleep. Task will be cancelled on close + ) + if self._is_closed.is_set(): + # don't refresh if client is closed + break + start_timestamp = time.monotonic() # prepare new channel for use old_channel = self.transport.grpc_channel new_channel = self.transport.create_channel() await self._ping_and_warm_instances(channel=new_channel) # cycle channel out of use, with long grace window before closure self.transport._grpc_channel = new_channel - await old_channel.close(grace_period) - # subtract the time spent waiting for the channel to be replaced + # give old_channel a chance to complete existing rpcs + if CrossSync.is_async: + await old_channel.close(grace_period) + else: + if grace_period: + self._is_closed.wait(grace_period) # type: ignore + old_channel.close() # type: ignore + # subtract thed time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) - next_sleep = next_refresh - (time.time() - start_timestamp) + next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) + @CrossSync.convert( + replace_symbols={ + "TableAsync": "Table", + "ExecuteQueryIteratorAsync": "ExecuteQueryIterator", + } + ) async def _register_instance( - self, instance_id: str, owner: Union[TableAsync, ExecuteQueryIteratorAsync] + self, instance_id: str, owner: TableAsync | ExecuteQueryIteratorAsync ) -> None: """ Registers an instance with the client, and warms the channel for the instance @@ -344,8 +409,14 @@ async def _register_instance( # refresh tasks aren't active. start them as background tasks self._start_background_channel_refresh() + @CrossSync.convert( + replace_symbols={ + "TableAsync": "Table", + "ExecuteQueryIteratorAsync": "ExecuteQueryIterator", + } + ) async def _remove_instance_registration( - self, instance_id: str, owner: Union[TableAsync, ExecuteQueryIteratorAsync] + self, instance_id: str, owner: TableAsync | "ExecuteQueryIteratorAsync" ) -> bool: """ Removes an instance from the client's registered instances, to prevent @@ -374,11 +445,26 @@ async def _remove_instance_registration( except KeyError: return False + @CrossSync.convert( + replace_symbols={"TableAsync": "Table"}, + docstring_format_vars={ + "LOOP_MESSAGE": ( + "Must be created within an async context (running event loop)", + "", + ), + "RAISE_NO_LOOP": ( + "RuntimeError: if called outside of an async context (no running event loop)", + "None", + ), + }, + ) def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: """ Returns a table instance for making data API requests. All arguments are passed directly to the TableAsync constructor. + {LOOP_MESSAGE} + Args: instance_id: The Bigtable instance ID to associate with this client. 
instance_id is combined with the client's project to fully @@ -411,17 +497,20 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs Returns: TableAsync: a table instance for making data API requests Raises: - RuntimeError: if called outside of an async context (no running event loop) + {RAISE_NO_LOOP} """ return TableAsync(self, instance_id, table_id, *args, **kwargs) + @CrossSync.convert( + replace_symbols={"ExecuteQueryIteratorAsync": "ExecuteQueryIterator"} + ) async def execute_query( self, query: str, instance_id: str, *, - parameters: Dict[str, ExecuteQueryValueType] | None = None, - parameter_types: Dict[str, SqlType.Type] | None = None, + parameters: dict[str, ExecuteQueryValueType] | None = None, + parameter_types: dict[str, SqlType.Type] | None = None, app_profile_id: str | None = None, operation_timeout: float = 600, attempt_timeout: float | None = 20, @@ -491,7 +580,7 @@ async def execute_query( "proto_format": {}, } - return ExecuteQueryIteratorAsync( + return CrossSync.ExecuteQueryIterator( self, instance_id, app_profile_id, @@ -501,15 +590,18 @@ async def execute_query( retryable_excs=retryable_excs, ) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): self._start_background_channel_refresh() return self + @CrossSync.convert(sync_name="__exit__", replace_symbols={"__aexit__": "__exit__"}) async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) +@CrossSync.convert_class(sync_name="Table", add_mapping_for_name="Table") class TableAsync: """ Main Data API surface @@ -518,6 +610,19 @@ class TableAsync: each call """ + @CrossSync.convert( + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}, + docstring_format_vars={ + "LOOP_MESSAGE": ( + "Must be created within an async context (running event loop)", + "", + ), + "RAISE_NO_LOOP": ( + "RuntimeError: if called outside of an async context (no running event loop)", + "None", + ), + }, + ) def __init__( self, client: BigtableDataClientAsync, @@ -548,7 +653,7 @@ def __init__( """ Initialize a Table instance - Must be created within an async context (running event loop) + {LOOP_MESSAGE} Args: instance_id: The Bigtable instance ID to associate with this client. @@ -580,7 +685,7 @@ def __init__( encountered during all other operations. Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) Raises: - RuntimeError: if called outside of an async context (no running event loop) + {RAISE_NO_LOOP} """ # NOTE: any changes to the signature of this method should also be reflected # in client.get_table() @@ -626,17 +731,19 @@ def __init__( default_mutate_rows_retryable_errors or () ) self.default_retryable_errors = default_retryable_errors or () - - # raises RuntimeError if called outside of an async context (no running event loop) try: - self._register_instance_task = asyncio.create_task( - self.client._register_instance(instance_id, self) + self._register_instance_future = CrossSync.create_task( + self.client._register_instance, + self.instance_id, + self, + sync_executor=self.client._executor, ) except RuntimeError as e: raise RuntimeError( f"{self.__class__.__name__} must be created within an async event loop context." 
) from e + @CrossSync.convert(replace_symbols={"AsyncIterable": "Iterable"}) async def read_rows_stream( self, query: ReadRowsQuery, @@ -678,7 +785,7 @@ async def read_rows_stream( ) retryable_excs = _get_retryable_errors(retryable_errors, self) - row_merger = _ReadRowsOperationAsync( + row_merger = CrossSync._ReadRowsOperation( query, self, operation_timeout=operation_timeout, @@ -687,6 +794,7 @@ async def read_rows_stream( ) return row_merger.start_operation() + @CrossSync.convert async def read_rows( self, query: ReadRowsQuery, @@ -734,6 +842,7 @@ async def read_rows( ) return [row async for row in row_generator] + @CrossSync.convert async def read_row( self, row_key: str | bytes, @@ -783,6 +892,7 @@ async def read_row( return None return results[0] + @CrossSync.convert async def read_rows_sharded( self, sharded_query: ShardedQuery, @@ -833,8 +943,9 @@ async def read_rows_sharded( ) # limit the number of concurrent requests using a semaphore - concurrency_sem = asyncio.Semaphore(_CONCURRENCY_LIMIT) + concurrency_sem = CrossSync.Semaphore(_CONCURRENCY_LIMIT) + @CrossSync.convert async def read_rows_with_semaphore(query): async with concurrency_sem: # calculate new timeout based on time left in overall operation @@ -850,8 +961,14 @@ async def read_rows_with_semaphore(query): retryable_errors=retryable_errors, ) - routine_list = [read_rows_with_semaphore(query) for query in sharded_query] - batch_result = await asyncio.gather(*routine_list, return_exceptions=True) + routine_list = [ + partial(read_rows_with_semaphore, query) for query in sharded_query + ] + batch_result = await CrossSync.gather_partials( + routine_list, + return_exceptions=True, + sync_executor=self.client._executor, + ) # collect results and errors error_dict = {} @@ -878,6 +995,7 @@ async def read_rows_with_semaphore(query): ) return results_list + @CrossSync.convert async def row_exists( self, row_key: str | bytes, @@ -926,6 +1044,7 @@ async def row_exists( ) return len(results) > 0 + @CrossSync.convert async def sample_row_keys( self, *, @@ -977,7 +1096,7 @@ async def sample_row_keys( sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - # prepare request + @CrossSync.convert async def execute_rpc(): results = await self.client._gapic_client.sample_row_keys( table_name=self.table_name, @@ -987,7 +1106,7 @@ async def execute_rpc(): ) return [(s.row_key, s.offset_bytes) async for s in results] - return await retries.retry_target_async( + return await CrossSync.retry_target( execute_rpc, predicate, sleep_generator, @@ -995,6 +1114,7 @@ async def execute_rpc(): exception_factory=_retry_exception_factory, ) + @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) def mutations_batcher( self, *, @@ -1007,7 +1127,7 @@ def mutations_batcher( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - ) -> MutationsBatcherAsync: + ) -> "MutationsBatcherAsync": """ Returns a new mutations batcher instance. 
@@ -1032,7 +1152,7 @@ def mutations_batcher( Returns: MutationsBatcherAsync: a MutationsBatcherAsync context manager that can batch requests """ - return MutationsBatcherAsync( + return CrossSync.MutationsBatcher( self, flush_interval=flush_interval, flush_limit_mutation_count=flush_limit_mutation_count, @@ -1044,6 +1164,7 @@ def mutations_batcher( batch_retryable_errors=batch_retryable_errors, ) + @CrossSync.convert async def mutate_row( self, row_key: str | bytes, @@ -1113,7 +1234,7 @@ async def mutate_row( timeout=attempt_timeout, retry=None, ) - return await retries.retry_target_async( + return await CrossSync.retry_target( target, predicate, sleep_generator, @@ -1121,6 +1242,7 @@ async def mutate_row( exception_factory=_retry_exception_factory, ) + @CrossSync.convert async def bulk_mutate_rows( self, mutation_entries: list[RowMutationEntry], @@ -1166,7 +1288,7 @@ async def bulk_mutate_rows( ) retryable_excs = _get_retryable_errors(retryable_errors, self) - operation = _MutateRowsOperationAsync( + operation = CrossSync._MutateRowsOperation( self.client._gapic_client, self, mutation_entries, @@ -1176,6 +1298,7 @@ async def bulk_mutate_rows( ) await operation.start() + @CrossSync.convert async def check_and_mutate_row( self, row_key: str | bytes, @@ -1240,6 +1363,7 @@ async def check_and_mutate_row( ) return result.predicate_matched + @CrossSync.convert async def read_modify_write_row( self, row_key: str | bytes, @@ -1288,13 +1412,16 @@ async def read_modify_write_row( # construct Row from result return Row._from_pb(result.row) + @CrossSync.convert async def close(self): """ Called to close the Table instance and release any resources held by it. """ - self._register_instance_task.cancel() + if self._register_instance_future: + self._register_instance_future.cancel() await self.client._remove_instance_registration(self.instance_id, self) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """ Implement async context manager protocol @@ -1302,9 +1429,11 @@ async def __aenter__(self): Ensure registration task has time to run, so that grpc channels will be warmed for the specified instance """ - await self._register_instance_task + if self._register_instance_future: + await self._register_instance_future return self + @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc_val, exc_tb): """ Implement async context manager protocol diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 76d13f00b..65070c880 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -14,32 +14,40 @@ # from __future__ import annotations -from typing import Any, Sequence, TYPE_CHECKING -import asyncio +from typing import Sequence, TYPE_CHECKING import atexit import warnings from collections import deque +import concurrent.futures -from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable.data._async._mutate_rows import ( +from google.cloud.bigtable.data.mutations 
import ( _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data._cross_sync import CrossSync + if TYPE_CHECKING: - from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data.mutations import RowMutationEntry + + if CrossSync.is_async: + from google.cloud.bigtable.data._async.client import TableAsync as TableType + else: + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType # type: ignore + +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.mutations_batcher" # used to make more readable default values _MB_SIZE = 1024 * 1024 +@CrossSync.convert_class(sync_name="_FlowControl", add_mapping_for_name="_FlowControl") class _FlowControlAsync: """ Manages flow control for batched mutations. Mutations are registered against @@ -70,7 +78,7 @@ def __init__( raise ValueError("max_mutation_count must be greater than 0") if self._max_mutation_bytes < 1: raise ValueError("max_mutation_bytes must be greater than 0") - self._capacity_condition = asyncio.Condition() + self._capacity_condition = CrossSync.Condition() self._in_flight_mutation_count = 0 self._in_flight_mutation_bytes = 0 @@ -96,6 +104,7 @@ def _has_capacity(self, additional_count: int, additional_size: int) -> bool: new_count = self._in_flight_mutation_count + additional_count return new_size <= acceptable_size and new_count <= acceptable_count + @CrossSync.convert async def remove_from_flow( self, mutations: RowMutationEntry | list[RowMutationEntry] ) -> None: @@ -117,6 +126,7 @@ async def remove_from_flow( async with self._capacity_condition: self._capacity_condition.notify_all() + @CrossSync.convert async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): """ Generator function that registers mutations with flow control. 
As mutations @@ -166,6 +176,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] yield mutations[start_idx:end_idx] +@CrossSync.convert_class(sync_name="MutationsBatcher") class MutationsBatcherAsync: """ Allows users to send batches using context manager API: @@ -199,7 +210,7 @@ class MutationsBatcherAsync: def __init__( self, - table: "TableAsync", + table: TableType, *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, @@ -218,11 +229,11 @@ def __init__( batch_retryable_errors, table ) - self.closed: bool = False + self._closed = CrossSync.Event() self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = 0, 0 - self._flow_control = _FlowControlAsync( + self._flow_control = CrossSync._FlowControl( flow_control_max_mutation_count, flow_control_max_bytes ) self._flush_limit_bytes = flush_limit_bytes @@ -231,8 +242,22 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - self._flush_timer = self._start_flush_timer(flush_interval) - self._flush_jobs: set[asyncio.Future[None]] = set() + # used by sync class to run mutate_rows operations + self._sync_rpc_executor = ( + concurrent.futures.ThreadPoolExecutor(max_workers=8) + if not CrossSync.is_async + else None + ) + # used by sync class to manage flush_internal tasks + self._sync_flush_executor = ( + concurrent.futures.ThreadPoolExecutor(max_workers=1) + if not CrossSync.is_async + else None + ) + self._flush_timer = CrossSync.create_task( + self._timer_routine, flush_interval, sync_executor=self._sync_flush_executor + ) + self._flush_jobs: set[CrossSync.Future[None]] = set() # MutationExceptionGroup reports number of successful entries along with failures self._entries_processed_since_last_raise: int = 0 self._exceptions_since_last_raise: int = 0 @@ -245,7 +270,8 @@ def __init__( # clean up on program exit atexit.register(self._on_exit) - def _start_flush_timer(self, interval: float | None) -> asyncio.Future[None]: + @CrossSync.convert + async def _timer_routine(self, interval: float | None) -> None: """ Set up a background task to flush the batcher every interval seconds @@ -254,27 +280,18 @@ def _start_flush_timer(self, interval: float | None) -> asyncio.Future[None]: Args: flush_interval: Automatically flush every flush_interval seconds. If None, no time-based flushing is performed. 
- Returns: - asyncio.Future[None]: future representing the background task """ - if interval is None or self.closed: - empty_future: asyncio.Future[None] = asyncio.Future() - empty_future.set_result(None) - return empty_future - - async def timer_routine(self, interval: float): - """ - Triggers new flush tasks every `interval` seconds - """ - while not self.closed: - await asyncio.sleep(interval) - # add new flush task to list - if not self.closed and self._staged_entries: - self._schedule_flush() - - timer_task = asyncio.create_task(timer_routine(self, interval)) - return timer_task + if not interval or interval <= 0: + return None + while not self._closed.is_set(): + # wait until interval has passed, or until closed + await CrossSync.event_wait( + self._closed, timeout=interval, async_break_early=False + ) + if not self._closed.is_set() and self._staged_entries: + self._schedule_flush() + @CrossSync.convert async def append(self, mutation_entry: RowMutationEntry): """ Add a new set of mutations to the internal queue @@ -286,7 +303,7 @@ async def append(self, mutation_entry: RowMutationEntry): ValueError: if an invalid mutation type is added """ # TODO: return a future to track completion of this entry - if self.closed: + if self._closed.is_set(): raise RuntimeError("Cannot append to closed MutationsBatcher") if isinstance(mutation_entry, Mutation): # type: ignore raise ValueError( @@ -302,25 +319,29 @@ async def append(self, mutation_entry: RowMutationEntry): ): self._schedule_flush() # yield to the event loop to allow flush to run - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() - def _schedule_flush(self) -> asyncio.Future[None] | None: + def _schedule_flush(self) -> CrossSync.Future[None] | None: """ Update the flush task to include the latest staged entries Returns: - asyncio.Future[None] | None: + Future[None] | None: future representing the background task, if started """ if self._staged_entries: entries, self._staged_entries = self._staged_entries, [] self._staged_count, self._staged_bytes = 0, 0 - new_task = self._create_bg_task(self._flush_internal, entries) - new_task.add_done_callback(self._flush_jobs.remove) - self._flush_jobs.add(new_task) + new_task = CrossSync.create_task( + self._flush_internal, entries, sync_executor=self._sync_flush_executor + ) + if not new_task.done(): + self._flush_jobs.add(new_task) + new_task.add_done_callback(self._flush_jobs.remove) return new_task return None + @CrossSync.convert async def _flush_internal(self, new_entries: list[RowMutationEntry]): """ Flushes a set of mutations to the server, and updates internal state @@ -329,9 +350,11 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): new_entries list of RowMutationEntry objects to flush """ # flush new entries - in_process_requests: list[asyncio.Future[list[FailedMutationEntryError]]] = [] + in_process_requests: list[CrossSync.Future[list[FailedMutationEntryError]]] = [] async for batch in self._flow_control.add_to_flow(new_entries): - batch_task = self._create_bg_task(self._execute_mutate_rows, batch) + batch_task = CrossSync.create_task( + self._execute_mutate_rows, batch, sync_executor=self._sync_rpc_executor + ) in_process_requests.append(batch_task) # wait for all inflight requests to complete found_exceptions = await self._wait_for_batch_results(*in_process_requests) @@ -339,6 +362,7 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): self._entries_processed_since_last_raise += len(new_entries) 
self._add_exceptions(found_exceptions) + @CrossSync.convert async def _execute_mutate_rows( self, batch: list[RowMutationEntry] ) -> list[FailedMutationEntryError]: @@ -355,7 +379,7 @@ async def _execute_mutate_rows( FailedMutationEntryError objects will not contain index information """ try: - operation = _MutateRowsOperationAsync( + operation = CrossSync._MutateRowsOperation( self._table.client._gapic_client, self._table, batch, @@ -419,10 +443,12 @@ def _raise_exceptions(self): entry_count=entry_count, ) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """Allow use of context manager API""" return self + @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc, tb): """ Allow use of context manager API. @@ -431,19 +457,30 @@ async def __aexit__(self, exc_type, exc, tb): """ await self.close() + @property + def closed(self) -> bool: + """ + Returns: + - True if the batcher is closed, False otherwise + """ + return self._closed.is_set() + + @CrossSync.convert async def close(self): """ Flush queue and clean up resources """ - self.closed = True + self._closed.set() self._flush_timer.cancel() self._schedule_flush() - if self._flush_jobs: - await asyncio.gather(*self._flush_jobs, return_exceptions=True) - try: - await self._flush_timer - except asyncio.CancelledError: - pass + # shut down executors + if self._sync_flush_executor: + with self._sync_flush_executor: + self._sync_flush_executor.shutdown(wait=True) + if self._sync_rpc_executor: + with self._sync_rpc_executor: + self._sync_rpc_executor.shutdown(wait=True) + await CrossSync.wait([*self._flush_jobs, self._flush_timer]) atexit.unregister(self._on_exit) # raise unreported exceptions self._raise_exceptions() @@ -452,32 +489,17 @@ def _on_exit(self): """ Called when program is exited. Raises warning if unflushed mutations remain """ - if not self.closed and self._staged_entries: + if not self._closed.is_set() and self._staged_entries: warnings.warn( f"MutationsBatcher for table {self._table.table_name} was not closed. " f"{len(self._staged_entries)} Unflushed mutations will not be sent to the server." ) @staticmethod - def _create_bg_task(func, *args, **kwargs) -> asyncio.Future[Any]: - """ - Create a new background task, and return a future - - This method wraps asyncio to make it easier to maintain subclasses - with different concurrency models. - - Args: - func: function to execute in background task - *args: positional arguments to pass to func - **kwargs: keyword arguments to pass to func - Returns: - asyncio.Future: Future object representing the background task - """ - return asyncio.create_task(func(*args, **kwargs)) - - @staticmethod + @CrossSync.convert async def _wait_for_batch_results( - *tasks: asyncio.Future[list[FailedMutationEntryError]] | asyncio.Future[None], + *tasks: CrossSync.Future[list[FailedMutationEntryError]] + | CrossSync.Future[None], ) -> list[Exception]: """ Takes in a list of futures representing _execute_mutate_rows tasks, @@ -494,19 +516,19 @@ async def _wait_for_batch_results( """ if not tasks: return [] - all_results = await asyncio.gather(*tasks, return_exceptions=True) - found_errors = [] - for result in all_results: - if isinstance(result, Exception): - # will receive direct Exception objects if request task fails - found_errors.append(result) - elif isinstance(result, BaseException): - # BaseException not expected from grpc calls. 
Raise immediately - raise result - elif result: - # completed requests will return a list of FailedMutationEntryError - for e in result: - # strip index information - e.index = None - found_errors.extend(result) - return found_errors + exceptions: list[Exception] = [] + for task in tasks: + if CrossSync.is_async: + # futures don't need to be awaited in sync mode + await task + try: + exc_list = task.result() + if exc_list: + # expect a list of FailedMutationEntryError objects + for exc in exc_list: + # strip index information + exc.index = None + exceptions.extend(exc_list) + except Exception as e: + exceptions.append(e) + return exceptions diff --git a/google/cloud/bigtable/data/exceptions.py b/google/cloud/bigtable/data/exceptions.py index 95cd44f2c..62f0b62fc 100644 --- a/google/cloud/bigtable/data/exceptions.py +++ b/google/cloud/bigtable/data/exceptions.py @@ -41,6 +41,21 @@ class _RowSetComplete(Exception): pass +class _ResetRow(Exception): # noqa: F811 + """ + Internal exception for _ReadRowsOperation + + Denotes that the server sent a reset_row marker, telling the client to drop + all previous chunks for row_key and re-read from the beginning. + + Args: + chunk: the reset_row chunk + """ + + def __init__(self, chunk): + self.chunk = chunk + + class _MutateRowsIncomplete(RuntimeError): """ Exception raised when a mutate_rows call has unfinished work. diff --git a/google/cloud/bigtable/data/execute_query/__init__.py b/google/cloud/bigtable/data/execute_query/__init__.py index 94af7d1cd..0ff258365 100644 --- a/google/cloud/bigtable/data/execute_query/__init__.py +++ b/google/cloud/bigtable/data/execute_query/__init__.py @@ -25,7 +25,9 @@ QueryResultRow, Struct, ) +from google.cloud.bigtable.data._cross_sync import CrossSync +CrossSync.add_mapping("ExecuteQueryIterator", ExecuteQueryIteratorAsync) __all__ = [ "ExecuteQueryValueType", diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index 6146ad451..ba82bbcca 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -14,10 +14,8 @@ from __future__ import annotations -import asyncio from typing import ( Any, - AsyncIterator, Dict, Optional, Sequence, @@ -43,40 +41,31 @@ ExecuteQueryRequest as ExecuteQueryRequestPB, ) +from google.cloud.bigtable.data._cross_sync import CrossSync + if TYPE_CHECKING: - from google.cloud.bigtable.data import BigtableDataClientAsync + if CrossSync.is_async: + from google.cloud.bigtable.data import BigtableDataClientAsync as DataClientType +__CROSS_SYNC_OUTPUT__ = ( + "google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator" +) -class ExecuteQueryIteratorAsync: - """ - ExecuteQueryIteratorAsync handles collecting streaming responses from the - ExecuteQuery RPC and parsing them to QueryResultRows. - - ExecuteQueryIteratorAsync implements Asynchronous Iterator interface and can - be used with "async for" syntax. It is also a context manager. - - It is **not thread-safe**. It should not be used by multiple asyncio Tasks. - - Args: - client: bigtable client - instance_id: id of the instance on which the query is executed - request_body: dict representing the body of the ExecuteQueryRequest - attempt_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to 600 seconds. 
- operation_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the 20 seconds. If None, defaults to operation_timeout. - req_metadata: metadata used while sending the gRPC request - retryable_excs: a list of errors that will be retried if encountered. - Raises: - RuntimeError: if the instance is not created within an async event loop context. - """ +@CrossSync.convert_class(sync_name="ExecuteQueryIterator") +class ExecuteQueryIteratorAsync: + @CrossSync.convert( + docstring_format_vars={ + "NO_LOOP": ( + "RuntimeError: if the instance is not created within an async event loop context.", + "None", + ), + "TASK_OR_THREAD": ("asyncio Tasks", "threads"), + } + ) def __init__( self, - client: BigtableDataClientAsync, + client: DataClientType, instance_id: str, app_profile_id: Optional[str], request_body: Dict[str, Any], @@ -85,6 +74,25 @@ def __init__( req_metadata: Sequence[Tuple[str, str]] = (), retryable_excs: Sequence[type[Exception]] = (), ) -> None: + """ + Collects responses from ExecuteQuery requests and parses them into QueryResultRows. + + It is **not thread-safe**. It should not be used by multiple {TASK_OR_THREAD}. + + Args: + client: bigtable client + instance_id: id of the instance on which the query is executed + request_body: dict representing the body of the ExecuteQueryRequest + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget + req_metadata: metadata used while sending the gRPC request + retryable_excs: a list of errors that will be retried if encountered. + Raises: + {NO_LOOP} + """ self._table_name = None self._app_profile_id = app_profile_id self._client = client @@ -98,8 +106,7 @@ def __init__( self._attempt_timeout_gen = _attempt_timeout_generator( attempt_timeout, operation_timeout ) - retryable_excs = retryable_excs or [] - self._async_stream = retries.retry_target_stream_async( + self._stream = CrossSync.retry_target_stream( self._make_request_with_resume_token, retries.if_exception_type(*retryable_excs), retries.exponential_sleep_generator(0.01, 60, multiplier=2), @@ -109,8 +116,11 @@ def __init__( self._req_metadata = req_metadata try: - self._register_instance_task = asyncio.create_task( - self._client._register_instance(instance_id, self) + self._register_instance_task = CrossSync.create_task( + self._client._register_instance, + instance_id, + self, + sync_executor=self._client._executor, ) except RuntimeError as e: raise RuntimeError( @@ -132,6 +142,7 @@ def table_name(self) -> Optional[str]: """Returns the table_name of the iterator.""" return self._table_name + @CrossSync.convert async def _make_request_with_resume_token(self): """ perfoms the rpc call using the correct resume token. @@ -150,23 +161,25 @@ async def _make_request_with_resume_token(self): retry=None, ) - async def _await_metadata(self) -> None: + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) + async def _fetch_metadata(self) -> None: """ If called before the first response was recieved, the first response - is awaited as part of this call. + is retrieved as part of this call. 
""" if self._byte_cursor.metadata is None: - metadata_msg = await self._async_stream.__anext__() + metadata_msg = await self._stream.__anext__() self._byte_cursor.consume_metadata(metadata_msg) - async def _next_impl(self) -> AsyncIterator[QueryResultRow]: + @CrossSync.convert + async def _next_impl(self) -> CrossSync.Iterator[QueryResultRow]: """ Generator wrapping the response stream which parses the stream results and returns full `QueryResultRow`s. """ - await self._await_metadata() + await self._fetch_metadata() - async for response in self._async_stream: + async for response in self._stream: try: bytes_to_parse = self._byte_cursor.consume(response) if bytes_to_parse is None: @@ -185,14 +198,17 @@ async def _next_impl(self) -> AsyncIterator[QueryResultRow]: yield result await self.close() + @CrossSync.convert(sync_name="__next__", replace_symbols={"__anext__": "__next__"}) async def __anext__(self) -> QueryResultRow: if self._is_closed: - raise StopAsyncIteration + raise CrossSync.StopIteration return await self._result_generator.__anext__() + @CrossSync.convert(sync_name="__iter__") def __aiter__(self): return self + @CrossSync.convert async def metadata(self) -> Optional[Metadata]: """ Returns query metadata from the server or None if the iterator was @@ -203,11 +219,12 @@ async def metadata(self) -> Optional[Metadata]: # Metadata should be present in the first response in a stream. if self._byte_cursor.metadata is None: try: - await self._await_metadata() - except StopIteration: + await self._fetch_metadata() + except CrossSync.StopIteration: return None return self._byte_cursor.metadata + @CrossSync.convert async def close(self) -> None: """ Cancel all background tasks. Should be called all rows were processed. diff --git a/google/cloud/bigtable/data/mutations.py b/google/cloud/bigtable/data/mutations.py index 335a15e12..2f4e441ed 100644 --- a/google/cloud/bigtable/data/mutations.py +++ b/google/cloud/bigtable/data/mutations.py @@ -366,3 +366,15 @@ def _from_dict(cls, input_dict: dict[str, Any]) -> RowMutationEntry: Mutation._from_dict(mutation) for mutation in input_dict["mutations"] ], ) + + +@dataclass +class _EntryWithProto: + """ + A dataclass to hold a RowMutationEntry and its corresponding proto representation. + + Used in _MutateRowsOperation to avoid repeated conversion of RowMutationEntry to proto. 
+ """ + + entry: RowMutationEntry + proto: types_pb.MutateRowsRequest.Entry diff --git a/noxfile.py b/noxfile.py index 4dfebe068..f6a2291fc 100644 --- a/noxfile.py +++ b/noxfile.py @@ -157,6 +157,8 @@ def mypy(session): "tests/system/v2_client", "--exclude", "tests/unit/v2_client", + "--disable-error-code", + "func-returns-value", # needed for CrossSync.rm_aio ) @@ -294,9 +296,8 @@ def system_emulated(session): @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) -def conformance(session): - TEST_REPO_URL = "https://github.com/googleapis/cloud-bigtable-clients-test.git" - CLONE_REPO_DIR = "cloud-bigtable-clients-test" +@nox.parametrize("client_type", ["async"]) +def conformance(session, client_type): # install dependencies constraints_path = str( CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" @@ -304,11 +305,13 @@ def conformance(session): install_unittest_dependencies(session, "-c", constraints_path) with session.chdir("test_proxy"): # download the conformance test suite - clone_dir = os.path.join(CURRENT_DIRECTORY, CLONE_REPO_DIR) - if not os.path.exists(clone_dir): - print("downloading copy of test repo") - session.run("git", "clone", TEST_REPO_URL, CLONE_REPO_DIR, external=True) - session.run("bash", "-e", "run_tests.sh", external=True) + session.run( + "bash", + "-e", + "run_tests.sh", + external=True, + env={"CLIENT_TYPE": client_type}, + ) @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) diff --git a/test_proxy/README.md b/test_proxy/README.md index 08741fd5d..266fba7cd 100644 --- a/test_proxy/README.md +++ b/test_proxy/README.md @@ -8,7 +8,7 @@ You can run the conformance tests in a single line by calling `nox -s conformanc ``` -cd python-bigtable/test_proxy +cd python-bigtable nox -s conformance ``` @@ -30,10 +30,11 @@ cd python-bigtable/test_proxy python test_proxy.py --port 8080 ``` -You can run the test proxy against the previous `v2` client by running it with the `--legacy-client` flag: +By default, the test_proxy targets the async client. You can change this by passing in the `--client_type` flag. +Valid options are `async` and `legacy`. ``` -python test_proxy.py --legacy-client +python test_proxy.py --client_type=legacy ``` ### Run the test cases diff --git a/test_proxy/handlers/client_handler_data.py b/test_proxy/handlers/client_handler_data_async.py similarity index 90% rename from test_proxy/handlers/client_handler_data.py rename to test_proxy/handlers/client_handler_data_async.py index 43ff5d634..7f6cc413f 100644 --- a/test_proxy/handlers/client_handler_data.py +++ b/test_proxy/handlers/client_handler_data_async.py @@ -18,8 +18,15 @@ from google.cloud.environment_vars import BIGTABLE_EMULATOR from google.cloud.bigtable.data import BigtableDataClientAsync +from google.cloud.bigtable.data._cross_sync import CrossSync +if not CrossSync.is_async: + from client_handler_data_async import error_safe +__CROSS_SYNC_OUTPUT__ = "test_proxy.handlers.client_handler_data_sync_autogen" + + +@CrossSync.drop def error_safe(func): """ Catch and pass errors back to the grpc_server_process @@ -37,6 +44,7 @@ async def wrapper(self, *args, **kwargs): return wrapper +@CrossSync.drop def encode_exception(exc): """ Encode an exception or chain of exceptions to pass back to grpc_handler @@ -68,7 +76,8 @@ def encode_exception(exc): return result -class TestProxyClientHandler: +@CrossSync.convert_class("TestProxyClientHandler") +class TestProxyClientHandlerAsync: """ Implements the same methods as the grpc server, but handles the client library side of the request. 
@@ -90,7 +99,7 @@ def __init__( self.closed = False # use emulator os.environ[BIGTABLE_EMULATOR] = data_target - self.client = BigtableDataClientAsync(project=project_id) + self.client = CrossSync.DataClient(project=project_id) self.instance_id = instance_id self.app_profile_id = app_profile_id self.per_operation_timeout = per_operation_timeout @@ -105,7 +114,7 @@ async def ReadRows(self, request, **kwargs): app_profile_id = self.app_profile_id or request.get("app_profile_id", None) table = self.client.get_table(self.instance_id, table_id, app_profile_id) kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - result_list = await table.read_rows(request, **kwargs) + result_list = CrossSync.rm_aio(await table.read_rows(request, **kwargs)) # pack results back into protobuf-parsable format serialized_response = [row._to_dict() for row in result_list] return serialized_response @@ -116,7 +125,7 @@ async def ReadRow(self, row_key, **kwargs): app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None) table = self.client.get_table(self.instance_id, table_id, app_profile_id) kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - result_row = await table.read_row(row_key, **kwargs) + result_row = CrossSync.rm_aio(await table.read_row(row_key, **kwargs)) # pack results back into protobuf-parsable format if result_row: return result_row._to_dict() @@ -132,7 +141,7 @@ async def MutateRow(self, request, **kwargs): kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 row_key = request["row_key"] mutations = [Mutation._from_dict(d) for d in request["mutations"]] - await table.mutate_row(row_key, mutations, **kwargs) + CrossSync.rm_aio(await table.mutate_row(row_key, mutations, **kwargs)) return "OK" @error_safe @@ -143,7 +152,7 @@ async def BulkMutateRows(self, request, **kwargs): table = self.client.get_table(self.instance_id, table_id, app_profile_id) kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 entry_list = [RowMutationEntry._from_dict(entry) for entry in request["entries"]] - await table.bulk_mutate_rows(entry_list, **kwargs) + CrossSync.rm_aio(await table.bulk_mutate_rows(entry_list, **kwargs)) return "OK" @error_safe @@ -171,13 +180,13 @@ async def CheckAndMutateRow(self, request, **kwargs): # invalid mutation type. 
Conformance test may be sending generic empty request false_mutations.append(SetCell("", "", "", 0)) predicate_filter = request.get("predicate_filter", None) - result = await table.check_and_mutate_row( + result = CrossSync.rm_aio(await table.check_and_mutate_row( row_key, predicate_filter, true_case_mutations=true_mutations, false_case_mutations=false_mutations, **kwargs, - ) + )) return result @error_safe @@ -197,7 +206,7 @@ async def ReadModifyWriteRow(self, request, **kwargs): else: new_rule = IncrementRule(rule_dict["family_name"], qualifier, rule_dict["increment_amount"]) rules.append(new_rule) - result = await table.read_modify_write_row(row_key, rules, **kwargs) + result = CrossSync.rm_aio(await table.read_modify_write_row(row_key, rules, **kwargs)) # pack results back into protobuf-parsable format if result: return result._to_dict() @@ -210,5 +219,5 @@ async def SampleRowKeys(self, request, **kwargs): app_profile_id = self.app_profile_id or request.get("app_profile_id", None) table = self.client.get_table(self.instance_id, table_id, app_profile_id) kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - result = await table.sample_row_keys(**kwargs) + result = CrossSync.rm_aio(await table.sample_row_keys(**kwargs)) return result diff --git a/test_proxy/handlers/client_handler_legacy.py b/test_proxy/handlers/client_handler_legacy.py index 400f618b5..63fe357b0 100644 --- a/test_proxy/handlers/client_handler_legacy.py +++ b/test_proxy/handlers/client_handler_legacy.py @@ -19,13 +19,13 @@ from google.cloud.environment_vars import BIGTABLE_EMULATOR from google.cloud.bigtable.client import Client -import client_handler_data as client_handler +import client_handler_data_async as client_handler import warnings warnings.filterwarnings("ignore", category=DeprecationWarning) -class LegacyTestProxyClientHandler(client_handler.TestProxyClientHandler): +class LegacyTestProxyClientHandler(client_handler.TestProxyClientHandlerAsync): def __init__( self, diff --git a/test_proxy/noxfile.py b/test_proxy/noxfile.py deleted file mode 100644 index bebf247b7..000000000 --- a/test_proxy/noxfile.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import absolute_import -import os -import pathlib -import re -from colorlog.escape_codes import parse_colors - -import nox - - -DEFAULT_PYTHON_VERSION = "3.10" - -PROXY_SERVER_PORT=os.environ.get("PROXY_SERVER_PORT", "50055") -PROXY_CLIENT_VERSION=os.environ.get("PROXY_CLIENT_VERSION", None) - -CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() -REPO_ROOT_DIRECTORY = CURRENT_DIRECTORY.parent - -nox.options.sessions = ["run_proxy", "conformance_tests"] - -TEST_REPO_URL = "https://github.com/googleapis/cloud-bigtable-clients-test.git" -CLONE_REPO_DIR = "cloud-bigtable-clients-test" - -# Error if a python version is missing -nox.options.error_on_missing_interpreters = True - - -def default(session): - """ - if nox is run directly, run the test_proxy session - """ - test_proxy(session) - - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def conformance_tests(session): - """ - download and run the conformance test suite against the test proxy - """ - import subprocess - import time - # download the conformance test suite - clone_dir = os.path.join(CURRENT_DIRECTORY, CLONE_REPO_DIR) - if not os.path.exists(clone_dir): - print("downloading copy of test repo") - session.run("git", "clone", TEST_REPO_URL, CLONE_REPO_DIR) - # start tests - with session.chdir(f"{clone_dir}/tests"): - session.run("go", "test", "-v", f"-proxy_addr=:{PROXY_SERVER_PORT}") - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def test_proxy(session): - """Start up the test proxy""" - # Install all dependencies, then install this package into the - # virtualenv's dist-packages. - # session.install( - # "grpcio", - # ) - if PROXY_CLIENT_VERSION is not None: - # install released version of the library - session.install(f"python-bigtable=={PROXY_CLIENT_VERSION}") - else: - # install the library from the source - session.install("-e", str(REPO_ROOT_DIRECTORY)) - session.install("-e", str(REPO_ROOT_DIRECTORY / "python-api-core")) - - session.run("python", "test_proxy.py", "--port", PROXY_SERVER_PORT, *session.posargs,) diff --git a/test_proxy/run_tests.sh b/test_proxy/run_tests.sh index 15b146b03..c2e9c6312 100755 --- a/test_proxy/run_tests.sh +++ b/test_proxy/run_tests.sh @@ -35,7 +35,8 @@ if [ ! -d "cloud-bigtable-clients-test" ]; then fi # start proxy -python test_proxy.py --port $PROXY_SERVER_PORT & +echo "starting with client type: $CLIENT_TYPE" +python test_proxy.py --port $PROXY_SERVER_PORT --client_type $CLIENT_TYPE & PROXY_PID=$! 
function finish { kill $PROXY_PID diff --git a/test_proxy/test_proxy.py b/test_proxy/test_proxy.py index a0cf2f1f0..9e03f1e5c 100644 --- a/test_proxy/test_proxy.py +++ b/test_proxy/test_proxy.py @@ -55,7 +55,7 @@ def grpc_server_process(request_q, queue_pool, port=50055): server.wait_for_termination() -async def client_handler_process_async(request_q, queue_pool, use_legacy_client=False): +async def client_handler_process_async(request_q, queue_pool, client_type="async"): """ Defines a process that recives Bigtable requests from a grpc_server_process, and runs the request using a client library instance @@ -64,8 +64,7 @@ async def client_handler_process_async(request_q, queue_pool, use_legacy_client= import re import asyncio import warnings - import client_handler_data - import client_handler_legacy + import client_handler_data_async warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*Bigtable emulator.*") def camel_to_snake(str): @@ -98,9 +97,7 @@ def format_dict(input_obj): return input_obj # Listen to requests from grpc server process - print_msg = "client_handler_process started" - if use_legacy_client: - print_msg += ", using legacy client" + print_msg = f"client_handler_process started with client_type={client_type}" print(print_msg) client_map = {} background_tasks = set() @@ -114,10 +111,11 @@ def format_dict(input_obj): client = client_map.get(client_id, None) # handle special cases for client creation and deletion if fn_name == "CreateClient": - if use_legacy_client: + if client_type == "legacy": + import client_handler_legacy client = client_handler_legacy.LegacyTestProxyClientHandler(**json_data) else: - client = client_handler_data.TestProxyClientHandler(**json_data) + client = client_handler_data_async.TestProxyClientHandlerAsync(**json_data) client_map[client_id] = client out_q.put(True) elif client is None: @@ -142,21 +140,21 @@ async def _run_fn(out_q, fn, **kwargs): await asyncio.sleep(0.01) -def client_handler_process(request_q, queue_pool, legacy_client=False): +def client_handler_process(request_q, queue_pool, client_type="async"): """ Sync entrypoint for client_handler_process_async """ import asyncio - asyncio.run(client_handler_process_async(request_q, queue_pool, legacy_client)) + asyncio.run(client_handler_process_async(request_q, queue_pool, client_type)) p = argparse.ArgumentParser() p.add_argument("--port", dest='port', default="50055") -p.add_argument('--legacy-client', dest='use_legacy', action='store_true', default=False) +p.add_argument("--client_type", dest='client_type', default="async", choices=["async", "legacy"]) if __name__ == "__main__": port = p.parse_args().port - use_legacy_client = p.parse_args().use_legacy + client_type = p.parse_args().client_type # start and run both processes # larger pools support more concurrent requests @@ -176,7 +174,7 @@ def client_handler_process(request_q, queue_pool, legacy_client=False): ), ) proxy.start() - client_handler_process(request_q, response_queue_pool, use_legacy_client) + client_handler_process(request_q, response_queue_pool, client_type) proxy.join() else: # run proxy in forground and client in background diff --git a/tests/system/data/__init__.py b/tests/system/data/__init__.py index 89a37dc92..f2952b2cd 100644 --- a/tests/system/data/__init__.py +++ b/tests/system/data/__init__.py @@ -13,3 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# + +TEST_FAMILY = "test-family" +TEST_FAMILY_2 = "test-family-2" diff --git a/tests/system/data/setup_fixtures.py b/tests/system/data/setup_fixtures.py index 77086b7f3..3b5a0af06 100644 --- a/tests/system/data/setup_fixtures.py +++ b/tests/system/data/setup_fixtures.py @@ -17,20 +17,10 @@ """ import pytest -import pytest_asyncio import os -import asyncio import uuid -@pytest.fixture(scope="session") -def event_loop(): - loop = asyncio.get_event_loop() - yield loop - loop.stop() - loop.close() - - @pytest.fixture(scope="session") def admin_client(): """ @@ -150,22 +140,7 @@ def table_id( print(f"Table {init_table_id} not found, skipping deletion") -@pytest_asyncio.fixture(scope="session") -async def client(): - from google.cloud.bigtable.data import BigtableDataClientAsync - - project = os.getenv("GOOGLE_CLOUD_PROJECT") or None - async with BigtableDataClientAsync(project=project, pool_size=4) as client: - yield client - - @pytest.fixture(scope="session") def project_id(client): """Returns the project ID from the client.""" yield client.project - - -@pytest_asyncio.fixture(scope="session") -async def table(client, table_id, instance_id): - async with client.get_table(instance_id, table_id) as table: - yield table diff --git a/tests/system/data/test_execute_query_async.py b/tests/system/data/test_execute_query_async.py deleted file mode 100644 index 489dfeab6..000000000 --- a/tests/system/data/test_execute_query_async.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest - -import os -from unittest import mock -from .test_execute_query_utils import ( - ChannelMockAsync, - response_with_metadata, - response_with_result, -) -from google.api_core import exceptions as core_exceptions -from google.cloud.bigtable.data import BigtableDataClientAsync - -TABLE_NAME = "TABLE_NAME" -INSTANCE_NAME = "INSTANCE_NAME" - - -class TestAsyncExecuteQuery: - @pytest.fixture() - def async_channel_mock(self): - with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): - yield ChannelMockAsync() - - @pytest.fixture() - def async_client(self, async_channel_mock): - with mock.patch.dict( - os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"} - ), mock.patch("grpc.aio.insecure_channel", return_value=async_channel_mock): - yield BigtableDataClientAsync() - - @pytest.mark.asyncio - async def test_execute_query(self, async_client, async_channel_mock): - values = [ - response_with_metadata(), - response_with_result("test"), - response_with_result(8, resume_token=b"r1"), - response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - response_with_result("test3"), - response_with_result(None, resume_token=b"r3"), - ] - async_channel_mock.set_values(values) - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert results[0]["a"] == "test" - assert results[0]["b"] == 8 - assert results[1]["a"] == "test2" - assert results[1]["b"] == 9 - assert results[2]["a"] == "test3" - assert results[2]["b"] is None - assert len(async_channel_mock.execute_query_calls) == 1 - - @pytest.mark.asyncio - async def test_execute_query_with_params(self, async_client, async_channel_mock): - values = [ - response_with_metadata(), - response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME} WHERE b=@b", - INSTANCE_NAME, - parameters={"b": 9}, - ) - results = [r async for r in result] - assert len(results) == 1 - assert results[0]["a"] == "test2" - assert results[0]["b"] == 9 - assert len(async_channel_mock.execute_query_calls) == 1 - - @pytest.mark.asyncio - async def test_execute_query_error_before_metadata( - self, async_client, async_channel_mock - ): - from google.api_core.exceptions import DeadlineExceeded - - values = [ - DeadlineExceeded(""), - response_with_metadata(), - response_with_result("test"), - response_with_result(8, resume_token=b"r1"), - response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - response_with_result("test3"), - response_with_result(None, resume_token=b"r3"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert len(results) == 3 - assert len(async_channel_mock.execute_query_calls) == 2 - - @pytest.mark.asyncio - async def test_execute_query_error_after_metadata( - self, async_client, async_channel_mock - ): - from google.api_core.exceptions import DeadlineExceeded - - values = [ - response_with_metadata(), - DeadlineExceeded(""), - response_with_metadata(), - response_with_result("test"), - response_with_result(8, resume_token=b"r1"), - response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - response_with_result("test3"), - response_with_result(None, resume_token=b"r3"), - ] - async_channel_mock.set_values(values) - - 
result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert len(results) == 3 - assert len(async_channel_mock.execute_query_calls) == 2 - assert async_channel_mock.resume_tokens == [] - - @pytest.mark.asyncio - async def test_execute_query_with_retries(self, async_client, async_channel_mock): - from google.api_core.exceptions import DeadlineExceeded - - values = [ - response_with_metadata(), - response_with_result("test"), - response_with_result(8, resume_token=b"r1"), - DeadlineExceeded(""), - response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - response_with_result("test3"), - DeadlineExceeded(""), - response_with_result("test3"), - response_with_result(None, resume_token=b"r3"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert results[0]["a"] == "test" - assert results[0]["b"] == 8 - assert results[1]["a"] == "test2" - assert results[1]["b"] == 9 - assert results[2]["a"] == "test3" - assert results[2]["b"] is None - assert len(async_channel_mock.execute_query_calls) == 3 - assert async_channel_mock.resume_tokens == [b"r1", b"r2"] - - @pytest.mark.parametrize( - "exception", - [ - (core_exceptions.DeadlineExceeded("")), - (core_exceptions.Aborted("")), - (core_exceptions.ServiceUnavailable("")), - ], - ) - @pytest.mark.asyncio - async def test_execute_query_retryable_error( - self, async_client, async_channel_mock, exception - ): - values = [ - response_with_metadata(), - response_with_result("test", resume_token=b"t1"), - exception, - response_with_result(8, resume_token=b"t2"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert len(results) == 1 - assert len(async_channel_mock.execute_query_calls) == 2 - assert async_channel_mock.resume_tokens == [b"t1"] - - @pytest.mark.asyncio - async def test_execute_query_retry_partial_row( - self, async_client, async_channel_mock - ): - values = [ - response_with_metadata(), - response_with_result("test", resume_token=b"t1"), - core_exceptions.DeadlineExceeded(""), - response_with_result(8, resume_token=b"t2"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert results[0]["a"] == "test" - assert results[0]["b"] == 8 - assert len(async_channel_mock.execute_query_calls) == 2 - assert async_channel_mock.resume_tokens == [b"t1"] - - @pytest.mark.parametrize( - "ExceptionType", - [ - (core_exceptions.InvalidArgument), - (core_exceptions.FailedPrecondition), - (core_exceptions.PermissionDenied), - (core_exceptions.MethodNotImplemented), - (core_exceptions.Cancelled), - (core_exceptions.AlreadyExists), - (core_exceptions.OutOfRange), - (core_exceptions.DataLoss), - (core_exceptions.Unauthenticated), - (core_exceptions.NotFound), - (core_exceptions.ResourceExhausted), - (core_exceptions.Unknown), - (core_exceptions.InternalServerError), - ], - ) - @pytest.mark.asyncio - async def test_execute_query_non_retryable( - self, async_client, async_channel_mock, ExceptionType - ): - values = [ - response_with_metadata(), - response_with_result("test"), - response_with_result(8, resume_token=b"r1"), - ExceptionType(""), - 
response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - response_with_result("test3"), - response_with_result(None, resume_token=b"r3"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - r = await result.__anext__() - assert r["a"] == "test" - assert r["b"] == 8 - - with pytest.raises(ExceptionType): - r = await result.__anext__() - - assert len(async_channel_mock.execute_query_calls) == 1 - assert async_channel_mock.resume_tokens == [] - - @pytest.mark.asyncio - async def test_execute_query_metadata_received_multiple_times_detected( - self, async_client, async_channel_mock - ): - values = [ - response_with_metadata(), - response_with_metadata(), - ] - async_channel_mock.set_values(values) - - with pytest.raises(Exception, match="Invalid ExecuteQuery response received"): - [ - r - async for r in await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - ] diff --git a/tests/system/data/test_execute_query_utils.py b/tests/system/data/test_execute_query_utils.py deleted file mode 100644 index 3439e04d2..000000000 --- a/tests/system/data/test_execute_query_utils.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from unittest import mock - -from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse -from google.cloud.bigtable_v2.types.data import ProtoRows, Value as PBValue -import grpc.aio - - -try: - # async mock for python3.7-10 - from asyncio import coroutine - - def async_mock(return_value=None): - coro = mock.Mock(name="CoroutineResult") - corofunc = mock.Mock(name="CoroutineFunction", side_effect=coroutine(coro)) - corofunc.coro = coro - corofunc.coro.return_value = return_value - return corofunc - -except ImportError: - # async mock for python3.11 or later - from unittest.mock import AsyncMock - - def async_mock(return_value=None): - return AsyncMock(return_value=return_value) - - -# ExecuteQueryResponse( -# metadata={ -# "proto_schema": { -# "columns": [ -# {"name": "test1", "type_": TYPE_INT}, -# {"name": "test2", "type_": TYPE_INT}, -# ] -# } -# } -# ), -# ExecuteQueryResponse( -# results={"proto_rows_batch": {"batch_data": messages[0]}} -# ), - - -def response_with_metadata(): - schema = {"a": "string_type", "b": "int64_type"} - return ExecuteQueryResponse( - { - "metadata": { - "proto_schema": { - "columns": [ - {"name": name, "type_": {_type: {}}} - for name, _type in schema.items() - ] - } - } - } - ) - - -def response_with_result(*args, resume_token=None): - if resume_token is None: - resume_token_dict = {} - else: - resume_token_dict = {"resume_token": resume_token} - - values = [] - for column_value in args: - if column_value is None: - pb_value = PBValue({}) - else: - pb_value = PBValue( - { - "int_value" - if isinstance(column_value, int) - else "string_value": column_value - } - ) - values.append(pb_value) - rows = ProtoRows(values=values) - - return ExecuteQueryResponse( - { - "results": { - "proto_rows_batch": { - "batch_data": ProtoRows.serialize(rows), - }, - **resume_token_dict, - } - } - ) - - -class ExecuteQueryStreamMock: - def __init__(self, parent): - self.parent = parent - self.iter = iter(self.parent.values) - - def __call__(self, *args, **kwargs): - request = args[0] - - self.parent.execute_query_calls.append(request) - if request.resume_token: - self.parent.resume_tokens.append(request.resume_token) - - def stream(): - for value in self.iter: - if isinstance(value, Exception): - raise value - else: - yield value - - return stream() - - -class ChannelMock: - def __init__(self): - self.execute_query_calls = [] - self.values = [] - self.resume_tokens = [] - - def set_values(self, values): - self.values = values - - def unary_unary(self, *args, **kwargs): - return mock.MagicMock() - - def unary_stream(self, *args, **kwargs): - if args[0] == "/google.bigtable.v2.Bigtable/ExecuteQuery": - return ExecuteQueryStreamMock(self) - return mock.MagicMock() - - -class ChannelMockAsync(grpc.aio.Channel, mock.MagicMock): - def __init__(self, *args, **kwargs): - mock.MagicMock.__init__(self, *args, **kwargs) - self.execute_query_calls = [] - self.values = [] - self.resume_tokens = [] - self._iter = [] - - def get_async_get(self, *args, **kwargs): - return self.async_gen - - def set_values(self, values): - self.values = values - self._iter = iter(self.values) - - def unary_unary(self, *args, **kwargs): - return async_mock() - - def unary_stream(self, *args, **kwargs): - if args[0] == "/google.bigtable.v2.Bigtable/ExecuteQuery": - - async def async_gen(*args, **kwargs): - for value in self._iter: - yield value - - iter = async_gen() - - class UnaryStreamCallMock(grpc.aio.UnaryStreamCall): - def __aiter__(self): - async def _impl(*args, **kwargs): - try: - while 
True: - yield await self.read() - except StopAsyncIteration: - pass - - return _impl() - - async def read(self): - value = await iter.__anext__() - if isinstance(value, Exception): - raise value - return value - - def add_done_callback(*args, **kwargs): - pass - - def cancel(*args, **kwargs): - pass - - def cancelled(*args, **kwargs): - pass - - def code(*args, **kwargs): - pass - - def details(*args, **kwargs): - pass - - def done(*args, **kwargs): - pass - - def initial_metadata(*args, **kwargs): - pass - - def time_remaining(*args, **kwargs): - pass - - def trailing_metadata(*args, **kwargs): - pass - - async def wait_for_connection(*args, **kwargs): - return async_mock() - - class UnaryStreamMultiCallableMock(grpc.aio.UnaryStreamMultiCallable): - def __init__(self, parent): - self.parent = parent - - def __call__( - self, - request, - *, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None - ): - self.parent.execute_query_calls.append(request) - if request.resume_token: - self.parent.resume_tokens.append(request.resume_token) - return UnaryStreamCallMock() - - def add_done_callback(*args, **kwargs): - pass - - def cancel(*args, **kwargs): - pass - - def cancelled(*args, **kwargs): - pass - - def code(*args, **kwargs): - pass - - def details(*args, **kwargs): - pass - - def done(*args, **kwargs): - pass - - def initial_metadata(*args, **kwargs): - pass - - def time_remaining(*args, **kwargs): - pass - - def trailing_metadata(*args, **kwargs): - pass - - def wait_for_connection(*args, **kwargs): - pass - - # unary_stream should return https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc.aio.UnaryStreamMultiCallable - # PTAL https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc.aio.Channel.unary_stream - return UnaryStreamMultiCallableMock(self) - return async_mock() - - def stream_unary(self, *args, **kwargs) -> grpc.aio.StreamUnaryMultiCallable: - raise NotImplementedError() - - def stream_stream(self, *args, **kwargs) -> grpc.aio.StreamStreamMultiCallable: - raise NotImplementedError() - - async def close(self, grace=None): - return - - async def channel_ready(self): - return - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: - raise NotImplementedError() - - async def wait_for_state_change(self, last_observed_state): - raise NotImplementedError() diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py deleted file mode 100644 index 8f31827ed..000000000 --- a/tests/system/data/test_system.py +++ /dev/null @@ -1,937 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest -import pytest_asyncio -import asyncio -import uuid -import os -from google.api_core import retry -from google.api_core.exceptions import ClientError - -from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE -from google.cloud.environment_vars import BIGTABLE_EMULATOR - -TEST_FAMILY = "test-family" -TEST_FAMILY_2 = "test-family-2" - - -@pytest.fixture(scope="session") -def column_family_config(): - """ - specify column families to create when creating a new test table - """ - from google.cloud.bigtable_admin_v2 import types - - return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} - - -@pytest.fixture(scope="session") -def init_table_id(): - """ - The table_id to use when creating a new test table - """ - return f"test-table-{uuid.uuid4().hex}" - - -@pytest.fixture(scope="session") -def cluster_config(project_id): - """ - Configuration for the clusters to use when creating a new instance - """ - from google.cloud.bigtable_admin_v2 import types - - cluster = { - "test-cluster": types.Cluster( - location=f"projects/{project_id}/locations/us-central1-b", - serve_nodes=1, - ) - } - return cluster - - -class TempRowBuilder: - """ - Used to add rows to a table for testing purposes. - """ - - def __init__(self, table): - self.rows = [] - self.table = table - - async def add_row( - self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" - ): - if isinstance(value, str): - value = value.encode("utf-8") - elif isinstance(value, int): - value = value.to_bytes(8, byteorder="big", signed=True) - request = { - "table_name": self.table.table_name, - "row_key": row_key, - "mutations": [ - { - "set_cell": { - "family_name": family, - "column_qualifier": qualifier, - "value": value, - } - } - ], - } - await self.table.client._gapic_client.mutate_row(request) - self.rows.append(row_key) - - async def delete_rows(self): - if self.rows: - request = { - "table_name": self.table.table_name, - "entries": [ - {"row_key": row, "mutations": [{"delete_from_row": {}}]} - for row in self.rows - ], - } - await self.table.client._gapic_client.mutate_rows(request) - - -@pytest.mark.usefixtures("table") -async def _retrieve_cell_value(table, row_key): - """ - Helper to read an individual row - """ - from google.cloud.bigtable.data import ReadRowsQuery - - row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) - assert len(row_list) == 1 - row = row_list[0] - cell = row.cells[0] - return cell.value - - -async def _create_row_and_mutation( - table, temp_rows, *, start_value=b"start", new_value=b"new_value" -): - """ - Helper to create a new row, and a sample set_cell mutation to change its value - """ - from google.cloud.bigtable.data.mutations import SetCell - - row_key = uuid.uuid4().hex.encode() - family = TEST_FAMILY - qualifier = b"test-qualifier" - await temp_rows.add_row( - row_key, family=family, qualifier=qualifier, value=start_value - ) - # ensure cell is initialized - assert (await _retrieve_cell_value(table, row_key)) == start_value - - mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) - return row_key, mutation - - -@pytest_asyncio.fixture(scope="function") -async def temp_rows(table): - builder = TempRowBuilder(table) - yield builder - await builder.delete_rows() - - -@pytest.mark.usefixtures("table") -@pytest.mark.usefixtures("client") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=10) -@pytest.mark.asyncio -async def 
test_ping_and_warm_gapic(client, table): - """ - Simple ping rpc test - This test ensures channels are able to authenticate with backend - """ - request = {"name": table.instance_name} - await client._gapic_client.ping_and_warm(request) - - -@pytest.mark.usefixtures("table") -@pytest.mark.usefixtures("client") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_ping_and_warm(client, table): - """ - Test ping and warm from handwritten client - """ - results = await client._ping_and_warm_instances() - assert len(results) == 1 - assert results[0] is None - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -async def test_mutation_set_cell(table, temp_rows): - """ - Ensure cells can be set properly - """ - row_key = b"bulk_mutate" - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - await table.mutate_row(row_key, mutation) - - # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" -) -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_sample_row_keys(client, table, temp_rows, column_split_config): - """ - Sample keys should return a single sample in small test tables - """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - - results = await table.sample_row_keys() - assert len(results) == len(column_split_config) + 1 - # first keys should match the split config - for idx in range(len(column_split_config)): - assert results[idx][0] == column_split_config[idx] - assert isinstance(results[idx][1], int) - # last sample should be empty key - assert results[-1][0] == b"" - assert isinstance(results[-1][1], int) - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_bulk_mutations_set_cell(client, table, temp_rows): - """ - Ensure cells can be set properly - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - - await table.bulk_mutate_rows([bulk_mutation]) - - # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - - -@pytest.mark.asyncio -async def test_bulk_mutations_raise_exception(client, table): - """ - If an invalid mutation is passed, an exception should be raised - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell - from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedMutationEntryError - - row_key = uuid.uuid4().hex.encode() - mutation = SetCell(family="nonexistent", qualifier=b"test-qualifier", new_value=b"") - bulk_mutation = RowMutationEntry(row_key, [mutation]) - - with pytest.raises(MutationsExceptionGroup) as exc: - await table.bulk_mutate_rows([bulk_mutation]) - assert len(exc.value.exceptions) == 1 - entry_error = exc.value.exceptions[0] - assert isinstance(entry_error, FailedMutationEntryError) - 
assert entry_error.index == 0 - assert entry_error.entry == bulk_mutation - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_context_manager(client, table, temp_rows): - """ - test batcher with context manager. Should flush on exit - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - async with table.mutations_batcher() as batcher: - await batcher.append(bulk_mutation) - await batcher.append(bulk_mutation2) - # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - assert len(batcher._staged_entries) == 0 - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_timer_flush(client, table, temp_rows): - """ - batch should occur after flush_interval seconds - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - flush_interval = 0.1 - async with table.mutations_batcher(flush_interval=flush_interval) as batcher: - await batcher.append(bulk_mutation) - await asyncio.sleep(0) - assert len(batcher._staged_entries) == 1 - await asyncio.sleep(flush_interval + 0.1) - assert len(batcher._staged_entries) == 0 - # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_count_flush(client, table, temp_rows): - """ - batch should flush after flush_limit_mutation_count mutations - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - async with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - # should be noop; flush not scheduled - assert len(batcher._staged_entries) == 1 - await batcher.append(bulk_mutation2) - # task should now be scheduled - assert len(batcher._flush_jobs) == 1 - await asyncio.gather(*batcher._flush_jobs) - assert len(batcher._staged_entries) == 0 - assert len(batcher._flush_jobs) == 0 - # ensure cells were updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - assert (await _retrieve_cell_value(table, row_key2)) == new_value2 - - 
-@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_bytes_flush(client, table, temp_rows): - """ - batch should flush after flush_limit_bytes bytes - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 - - async with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - assert len(batcher._staged_entries) == 1 - await batcher.append(bulk_mutation2) - # task should now be scheduled - assert len(batcher._flush_jobs) == 1 - assert len(batcher._staged_entries) == 0 - # let flush complete - await asyncio.gather(*batcher._flush_jobs) - # ensure cells were updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - assert (await _retrieve_cell_value(table, row_key2)) == new_value2 - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_mutations_batcher_no_flush(client, table, temp_rows): - """ - test with no flush requirements met - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - start_value = b"unchanged" - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 - async with table.mutations_batcher( - flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 - ) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._staged_entries) == 1 - await batcher.append(bulk_mutation2) - # flush not scheduled - assert len(batcher._flush_jobs) == 0 - await asyncio.sleep(0.01) - assert len(batcher._staged_entries) == 2 - assert len(batcher._flush_jobs) == 0 - # ensure cells were not updated - assert (await _retrieve_cell_value(table, row_key)) == start_value - assert (await _retrieve_cell_value(table, row_key2)) == start_value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.parametrize( - "start,increment,expected", - [ - (0, 0, 0), - (0, 1, 1), - (0, -1, -1), - (1, 0, 1), - (0, -100, -100), - (0, 3000, 3000), - (10, 4, 14), - (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), - (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), - (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), - ], -) -@pytest.mark.asyncio -async def test_read_modify_write_row_increment( - client, table, temp_rows, start, increment, expected -): - """ - test read_modify_write_row - """ - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = 
b"test-qualifier" - await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) - - rule = IncrementRule(family, qualifier, increment) - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert int(result[0]) == expected - # ensure that reading from server gives same value - assert (await _retrieve_cell_value(table, row_key)) == result[0].value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.parametrize( - "start,append,expected", - [ - (b"", b"", b""), - ("", "", b""), - (b"abc", b"123", b"abc123"), - (b"abc", "123", b"abc123"), - ("", b"1", b"1"), - (b"abc", "", b"abc"), - (b"hello", b"world", b"helloworld"), - ], -) -@pytest.mark.asyncio -async def test_read_modify_write_row_append( - client, table, temp_rows, start, append, expected -): - """ - test read_modify_write_row - """ - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) - - rule = AppendValueRule(family, qualifier, append) - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert result[0].value == expected - # ensure that reading from server gives same value - assert (await _retrieve_cell_value(table, row_key)) == result[0].value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_modify_write_row_chained(client, table, temp_rows): - """ - test read_modify_write_row with multiple rules - """ - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - start_amount = 1 - increment_amount = 10 - await temp_rows.add_row( - row_key, value=start_amount, family=family, qualifier=qualifier - ) - rule = [ - IncrementRule(family, qualifier, increment_amount), - AppendValueRule(family, qualifier, "hello"), - AppendValueRule(family, qualifier, "world"), - AppendValueRule(family, qualifier, "!"), - ] - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert result[0].family == family - assert result[0].qualifier == qualifier - # result should be a bytes number string for the IncrementRules, followed by the AppendValueRule values - assert ( - result[0].value - == (start_amount + increment_amount).to_bytes(8, "big", signed=True) - + b"helloworld!" 
- ) - # ensure that reading from server gives same value - assert (await _retrieve_cell_value(table, row_key)) == result[0].value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.parametrize( - "start_val,predicate_range,expected_result", - [ - (1, (0, 2), True), - (-1, (0, 2), False), - ], -) -@pytest.mark.asyncio -async def test_check_and_mutate( - client, table, temp_rows, start_val, predicate_range, expected_result -): - """ - test that check_and_mutate_row works applies the right mutations, and returns the right result - """ - from google.cloud.bigtable.data.mutations import SetCell - from google.cloud.bigtable.data.row_filters import ValueRangeFilter - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - - await temp_rows.add_row( - row_key, value=start_val, family=family, qualifier=qualifier - ) - - false_mutation_value = b"false-mutation-value" - false_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value - ) - true_mutation_value = b"true-mutation-value" - true_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value - ) - predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) - result = await table.check_and_mutate_row( - row_key, - predicate, - true_case_mutations=true_mutation, - false_case_mutations=false_mutation, - ) - assert result == expected_result - # ensure cell is updated - expected_value = true_mutation_value if expected_result else false_mutation_value - assert (await _retrieve_cell_value(table, row_key)) == expected_value - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", -) -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_check_and_mutate_empty_request(client, table): - """ - check_and_mutate with no true or fale mutations should raise an error - """ - from google.api_core import exceptions - - with pytest.raises(exceptions.InvalidArgument) as e: - await table.check_and_mutate_row( - b"row_key", None, true_case_mutations=None, false_case_mutations=None - ) - assert "No mutations provided" in str(e.value) - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_stream(table, temp_rows): - """ - Ensure that the read_rows_stream method works - """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - - # full table scan - generator = await table.read_rows_stream({}) - first_row = await generator.__anext__() - second_row = await generator.__anext__() - assert first_row.row_key == b"row_key_1" - assert second_row.row_key == b"row_key_2" - with pytest.raises(StopAsyncIteration): - await generator.__anext__() - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows(table, temp_rows): - """ - Ensure that the read_rows method works - """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - # full table scan - row_list = await table.read_rows({}) - assert len(row_list) == 2 - assert row_list[0].row_key == b"row_key_1" - assert row_list[1].row_key == b"row_key_2" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) 
-@pytest.mark.asyncio -async def test_read_rows_sharded_simple(table, temp_rows): - """ - Test read rows sharded with two queries - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) - query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) - row_list = await table.read_rows_sharded([query1, query2]) - assert len(row_list) == 4 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"b" - assert row_list[3].row_key == b"d" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_sharded_from_sample(table, temp_rows): - """ - Test end-to-end sharding - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.read_rows_query import RowRange - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - - table_shard_keys = await table.sample_row_keys() - query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) - shard_queries = query.shard(table_shard_keys) - row_list = await table.read_rows_sharded(shard_queries) - assert len(row_list) == 3 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"d" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_sharded_filters_limits(table, temp_rows): - """ - Test read rows sharded with filters and limits - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - - label_filter1 = ApplyLabelFilter("first") - label_filter2 = ApplyLabelFilter("second") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) - query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) - row_list = await table.read_rows_sharded([query1, query2]) - assert len(row_list) == 3 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"b" - assert row_list[2].row_key == b"d" - assert row_list[0][0].labels == ["first"] - assert row_list[1][0].labels == ["second"] - assert row_list[2][0].labels == ["second"] - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_range_query(table, temp_rows): - """ - Ensure that the read_rows method works - """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data import RowRange - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # full table scan - query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) - row_list = await table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, 
maximum=5) -@pytest.mark.asyncio -async def test_read_rows_single_key_query(table, temp_rows): - """ - Ensure that the read_rows method works with specified query - """ - from google.cloud.bigtable.data import ReadRowsQuery - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # retrieve specific keys - query = ReadRowsQuery(row_keys=[b"a", b"c"]) - row_list = await table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"c" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_with_filter(table, temp_rows): - """ - ensure filters are applied - """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # retrieve keys with filter - expected_label = "test-label" - row_filter = ApplyLabelFilter(expected_label) - query = ReadRowsQuery(row_filter=row_filter) - row_list = await table.read_rows(query) - assert len(row_list) == 4 - for row in row_list: - assert row[0].labels == [expected_label] - - -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_rows_stream_close(table, temp_rows): - """ - Ensure that the read_rows_stream can be closed - """ - from google.cloud.bigtable.data import ReadRowsQuery - - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - # full table scan - query = ReadRowsQuery() - generator = await table.read_rows_stream(query) - # grab first row - first_row = await generator.__anext__() - assert first_row.row_key == b"row_key_1" - # close stream early - await generator.aclose() - with pytest.raises(StopAsyncIteration): - await generator.__anext__() - - -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_row(table, temp_rows): - """ - Test read_row (single row helper) - """ - from google.cloud.bigtable.data import Row - - await temp_rows.add_row(b"row_key_1", value=b"value") - row = await table.read_row(b"row_key_1") - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", -) -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_row_missing(table): - """ - Test read_row when row does not exist - """ - from google.api_core import exceptions - - row_key = "row_key_not_exist" - result = await table.read_row(row_key) - assert result is None - with pytest.raises(exceptions.InvalidArgument) as e: - await table.read_row("") - assert "Row keys must be non-empty" in str(e) - - -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_row_w_filter(table, temp_rows): - """ - Test read_row (single row helper) - """ - from google.cloud.bigtable.data import Row - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"row_key_1", value=b"value") - expected_label = "test-label" - label_filter = ApplyLabelFilter(expected_label) - row = await table.read_row(b"row_key_1", row_filter=label_filter) - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - 
assert row.cells[0].labels == [expected_label] - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", -) -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_row_exists(table, temp_rows): - from google.api_core import exceptions - - """Test row_exists with rows that exist and don't exist""" - assert await table.row_exists(b"row_key_1") is False - await temp_rows.add_row(b"row_key_1") - assert await table.row_exists(b"row_key_1") is True - assert await table.row_exists("row_key_1") is True - assert await table.row_exists(b"row_key_2") is False - assert await table.row_exists("row_key_2") is False - assert await table.row_exists("3") is False - await temp_rows.add_row(b"3") - assert await table.row_exists(b"3") is True - with pytest.raises(exceptions.InvalidArgument) as e: - await table.row_exists("") - assert "Row keys must be non-empty" in str(e) - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.parametrize( - "cell_value,filter_input,expect_match", - [ - (b"abc", b"abc", True), - (b"abc", "abc", True), - (b".", ".", True), - (".*", ".*", True), - (".*", b".*", True), - ("a", ".*", False), - (b".*", b".*", True), - (r"\a", r"\a", True), - (b"\xe2\x98\x83", "☃", True), - ("☃", "☃", True), - (r"\C☃", r"\C☃", True), - (1, 1, True), - (2, 1, False), - (68, 68, True), - ("D", 68, False), - (68, "D", False), - (-1, -1, True), - (2852126720, 2852126720, True), - (-1431655766, -1431655766, True), - (-1431655766, -1, False), - ], -) -@pytest.mark.asyncio -async def test_literal_value_filter( - table, temp_rows, cell_value, filter_input, expect_match -): - """ - Literal value filter does complex escaping on re2 strings. - Make sure inputs are properly interpreted by the server - """ - from google.cloud.bigtable.data.row_filters import LiteralValueFilter - from google.cloud.bigtable.data import ReadRowsQuery - - f = LiteralValueFilter(filter_input) - await temp_rows.add_row(b"row_key_1", value=cell_value) - query = ReadRowsQuery(row_filter=f) - row_list = await table.read_rows(query) - assert len(row_list) == bool( - expect_match - ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py new file mode 100644 index 000000000..c0e9f39d2 --- /dev/null +++ b/tests/system/data/test_system_async.py @@ -0,0 +1,1016 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import asyncio +import uuid +import os +from google.api_core import retry +from google.api_core.exceptions import ClientError + +from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE +from google.cloud.environment_vars import BIGTABLE_EMULATOR + +from google.cloud.bigtable.data._cross_sync import CrossSync + +from . 
import TEST_FAMILY, TEST_FAMILY_2 + + +__CROSS_SYNC_OUTPUT__ = "tests.system.data.test_system_autogen" + + +@CrossSync.convert_class( + sync_name="TempRowBuilder", + add_mapping_for_name="TempRowBuilder", +) +class TempRowBuilderAsync: + """ + Used to add rows to a table for testing purposes. + """ + + def __init__(self, table): + self.rows = [] + self.table = table + + @CrossSync.convert + async def add_row( + self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" + ): + if isinstance(value, str): + value = value.encode("utf-8") + elif isinstance(value, int): + value = value.to_bytes(8, byteorder="big", signed=True) + request = { + "table_name": self.table.table_name, + "row_key": row_key, + "mutations": [ + { + "set_cell": { + "family_name": family, + "column_qualifier": qualifier, + "value": value, + } + } + ], + } + await self.table.client._gapic_client.mutate_row(request) + self.rows.append(row_key) + + @CrossSync.convert + async def delete_rows(self): + if self.rows: + request = { + "table_name": self.table.table_name, + "entries": [ + {"row_key": row, "mutations": [{"delete_from_row": {}}]} + for row in self.rows + ], + } + await self.table.client._gapic_client.mutate_rows(request) + + +@CrossSync.convert_class(sync_name="TestSystem") +class TestSystemAsync: + @CrossSync.convert + @CrossSync.pytest_fixture(scope="session") + async def client(self): + project = os.getenv("GOOGLE_CLOUD_PROJECT") or None + async with CrossSync.DataClient(project=project) as client: + yield client + + @CrossSync.convert + @CrossSync.pytest_fixture(scope="session") + async def table(self, client, table_id, instance_id): + async with client.get_table(instance_id, table_id) as table: + yield table + + @CrossSync.drop + @pytest.fixture(scope="session") + def event_loop(self): + loop = asyncio.get_event_loop() + yield loop + loop.stop() + loop.close() + + @pytest.fixture(scope="session") + def column_family_config(self): + """ + specify column families to create when creating a new test table + """ + from google.cloud.bigtable_admin_v2 import types + + return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} + + @pytest.fixture(scope="session") + def init_table_id(self): + """ + The table_id to use when creating a new test table + """ + return f"test-table-{uuid.uuid4().hex}" + + @pytest.fixture(scope="session") + def cluster_config(self, project_id): + """ + Configuration for the clusters to use when creating a new instance + """ + from google.cloud.bigtable_admin_v2 import types + + cluster = { + "test-cluster": types.Cluster( + location=f"projects/{project_id}/locations/us-central1-b", + serve_nodes=1, + ) + } + return cluster + + @CrossSync.convert + @pytest.mark.usefixtures("table") + async def _retrieve_cell_value(self, table, row_key): + """ + Helper to read an individual row + """ + from google.cloud.bigtable.data import ReadRowsQuery + + row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) + assert len(row_list) == 1 + row = row_list[0] + cell = row.cells[0] + return cell.value + + @CrossSync.convert + async def _create_row_and_mutation( + self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" + ): + """ + Helper to create a new row, and a sample set_cell mutation to change its value + """ + from google.cloud.bigtable.data.mutations import SetCell + + row_key = uuid.uuid4().hex.encode() + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row( + row_key, family=family, qualifier=qualifier, 
value=start_value + ) + # ensure cell is initialized + assert await self._retrieve_cell_value(table, row_key) == start_value + + mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) + return row_key, mutation + + @CrossSync.convert + @CrossSync.pytest_fixture(scope="function") + async def temp_rows(self, table): + builder = CrossSync.TempRowBuilder(table) + yield builder + await builder.delete_rows() + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=10 + ) + @CrossSync.pytest + async def test_ping_and_warm_gapic(self, client, table): + """ + Simple ping rpc test + This test ensures channels are able to authenticate with backend + """ + request = {"name": table.instance_name} + await client._gapic_client.ping_and_warm(request) + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_ping_and_warm(self, client, table): + """ + Test ping and warm from handwritten client + """ + results = await client._ping_and_warm_instances() + assert len(results) == 1 + assert results[0] is None + + @CrossSync.pytest + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + async def test_mutation_set_cell(self, table, temp_rows): + """ + Ensure cells can be set properly + """ + row_key = b"bulk_mutate" + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + await table.mutate_row(row_key, mutation) + + # ensure cell is updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" + ) + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_sample_row_keys(self, client, table, temp_rows, column_split_config): + """ + Sample keys should return a single sample in small test tables + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + + results = await table.sample_row_keys() + assert len(results) == len(column_split_config) + 1 + # first keys should match the split config + for idx in range(len(column_split_config)): + assert results[idx][0] == column_split_config[idx] + assert isinstance(results[idx][1], int) + # last sample should be empty key + assert results[-1][0] == b"" + assert isinstance(results[-1][1], int) + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_bulk_mutations_set_cell(self, client, table, temp_rows): + """ + Ensure cells can be set properly + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + + await table.bulk_mutate_rows([bulk_mutation]) + + # ensure cell is updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + + @CrossSync.pytest + async def test_bulk_mutations_raise_exception(self, client, table): + """ + If an invalid mutation is 
passed, an exception should be raised + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + row_key = uuid.uuid4().hex.encode() + mutation = SetCell( + family="nonexistent", qualifier=b"test-qualifier", new_value=b"" + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + + with pytest.raises(MutationsExceptionGroup) as exc: + await table.bulk_mutate_rows([bulk_mutation]) + assert len(exc.value.exceptions) == 1 + entry_error = exc.value.exceptions[0] + assert isinstance(entry_error, FailedMutationEntryError) + assert entry_error.index == 0 + assert entry_error.entry == bulk_mutation + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_context_manager(self, client, table, temp_rows): + """ + test batcher with context manager. Should flush on exit + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + row_key2, mutation2 = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + async with table.mutations_batcher() as batcher: + await batcher.append(bulk_mutation) + await batcher.append(bulk_mutation2) + # ensure cell is updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert len(batcher._staged_entries) == 0 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_timer_flush(self, client, table, temp_rows): + """ + batch should occur after flush_interval seconds + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + async with table.mutations_batcher(flush_interval=flush_interval) as batcher: + await batcher.append(bulk_mutation) + await CrossSync.yield_to_event_loop() + assert len(batcher._staged_entries) == 1 + await CrossSync.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + # ensure cell is updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_count_flush(self, client, table, temp_rows): + """ + batch should flush after flush_limit_mutation_count mutations + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await 
self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + async with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + # should be noop; flush not scheduled + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # task should now be scheduled + assert len(batcher._flush_jobs) == 1 + # let flush complete + for future in list(batcher._flush_jobs): + await future + # for sync version: grab result + future.result() + assert len(batcher._staged_entries) == 0 + assert len(batcher._flush_jobs) == 0 + # ensure cells were updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert (await self._retrieve_cell_value(table, row_key2)) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): + """ + batch should flush after flush_limit_bytes bytes + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 + + async with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # task should now be scheduled + assert len(batcher._flush_jobs) == 1 + assert len(batcher._staged_entries) == 0 + # let flush complete + for future in list(batcher._flush_jobs): + await future + # for sync version: grab result + future.result() + # ensure cells were updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert (await self._retrieve_cell_value(table, row_key2)) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_mutations_batcher_no_flush(self, client, table, temp_rows): + """ + test with no flush requirements met + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + start_value = b"unchanged" + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 + async with table.mutations_batcher( + flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 + ) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # flush not scheduled + assert len(batcher._flush_jobs) == 0 + await 
CrossSync.yield_to_event_loop() + assert len(batcher._staged_entries) == 2 + assert len(batcher._flush_jobs) == 0 + # ensure cells were not updated + assert (await self._retrieve_cell_value(table, row_key)) == start_value + assert (await self._retrieve_cell_value(table, row_key2)) == start_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_large_batch(self, client, table, temp_rows): + """ + test batcher with large batch of mutations + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + + add_mutation = SetCell( + family=TEST_FAMILY, qualifier=b"test-qualifier", new_value=b"a" + ) + row_mutations = [] + for i in range(50_000): + row_key = uuid.uuid4().hex.encode() + row_mutations.append(RowMutationEntry(row_key, [add_mutation])) + # append row key for eventual deletion + temp_rows.rows.append(row_key) + + async with table.mutations_batcher() as batcher: + for mutation in row_mutations: + await batcher.append(mutation) + # ensure cell is updated + assert len(batcher._staged_entries) == 0 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,increment,expected", + [ + (0, 0, 0), + (0, 1, 1), + (0, -1, -1), + (1, 0, 1), + (0, -100, -100), + (0, 3000, 3000), + (10, 4, 14), + (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), + (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), + (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), + ], + ) + @CrossSync.pytest + async def test_read_modify_write_row_increment( + self, client, table, temp_rows, start, increment, expected + ): + """ + test read_modify_write_row + """ + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row( + row_key, value=start, family=family, qualifier=qualifier + ) + + rule = IncrementRule(family, qualifier, increment) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert int(result[0]) == expected + # ensure that reading from server gives same value + assert (await self._retrieve_cell_value(table, row_key)) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,append,expected", + [ + (b"", b"", b""), + ("", "", b""), + (b"abc", b"123", b"abc123"), + (b"abc", "123", b"abc123"), + ("", b"1", b"1"), + (b"abc", "", b"abc"), + (b"hello", b"world", b"helloworld"), + ], + ) + @CrossSync.pytest + async def test_read_modify_write_row_append( + self, client, table, temp_rows, start, append, expected + ): + """ + test read_modify_write_row + """ + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row( + row_key, value=start, family=family, qualifier=qualifier + ) + + rule = AppendValueRule(family, qualifier, append) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert result[0].value == expected + # ensure that reading from server 
gives same value
+        assert (await self._retrieve_cell_value(table, row_key)) == result[0].value
+
+    @pytest.mark.usefixtures("client")
+    @pytest.mark.usefixtures("table")
+    @CrossSync.pytest
+    async def test_read_modify_write_row_chained(self, client, table, temp_rows):
+        """
+        test read_modify_write_row with multiple rules
+        """
+        from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule
+        from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule
+
+        row_key = b"test-row-key"
+        family = TEST_FAMILY
+        qualifier = b"test-qualifier"
+        start_amount = 1
+        increment_amount = 10
+        await temp_rows.add_row(
+            row_key, value=start_amount, family=family, qualifier=qualifier
+        )
+        rule = [
+            IncrementRule(family, qualifier, increment_amount),
+            AppendValueRule(family, qualifier, "hello"),
+            AppendValueRule(family, qualifier, "world"),
+            AppendValueRule(family, qualifier, "!"),
+        ]
+        result = await table.read_modify_write_row(row_key, rule)
+        assert result.row_key == row_key
+        assert result[0].family == family
+        assert result[0].qualifier == qualifier
+        # result should be a bytes number string for the IncrementRules, followed by the AppendValueRule values
+        assert (
+            result[0].value
+            == (start_amount + increment_amount).to_bytes(8, "big", signed=True)
+            + b"helloworld!"
+        )
+        # ensure that reading from server gives same value
+        assert (await self._retrieve_cell_value(table, row_key)) == result[0].value
+
+    @pytest.mark.usefixtures("client")
+    @pytest.mark.usefixtures("table")
+    @pytest.mark.parametrize(
+        "start_val,predicate_range,expected_result",
+        [
+            (1, (0, 2), True),
+            (-1, (0, 2), False),
+        ],
+    )
+    @CrossSync.pytest
+    async def test_check_and_mutate(
+        self, client, table, temp_rows, start_val, predicate_range, expected_result
+    ):
+        """
+        test that check_and_mutate_row applies the right mutations, and returns the right result
+        """
+        from google.cloud.bigtable.data.mutations import SetCell
+        from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+        row_key = b"test-row-key"
+        family = TEST_FAMILY
+        qualifier = b"test-qualifier"
+
+        await temp_rows.add_row(
+            row_key, value=start_val, family=family, qualifier=qualifier
+        )
+
+        false_mutation_value = b"false-mutation-value"
+        false_mutation = SetCell(
+            family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value
+        )
+        true_mutation_value = b"true-mutation-value"
+        true_mutation = SetCell(
+            family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value
+        )
+        predicate = ValueRangeFilter(predicate_range[0], predicate_range[1])
+        result = await table.check_and_mutate_row(
+            row_key,
+            predicate,
+            true_case_mutations=true_mutation,
+            false_case_mutations=false_mutation,
+        )
+        assert result == expected_result
+        # ensure cell is updated
+        expected_value = (
+            true_mutation_value if expected_result else false_mutation_value
+        )
+        assert (await self._retrieve_cell_value(table, row_key)) == expected_value
+
+    @pytest.mark.skipif(
+        bool(os.environ.get(BIGTABLE_EMULATOR)),
+        reason="emulator doesn't raise InvalidArgument",
+    )
+    @pytest.mark.usefixtures("client")
+    @pytest.mark.usefixtures("table")
+    @CrossSync.pytest
+    async def test_check_and_mutate_empty_request(self, client, table):
+        """
+        check_and_mutate with no true or false mutations should raise an error
+        """
+        from google.api_core import exceptions
+
+        with pytest.raises(exceptions.InvalidArgument) as e:
+            await table.check_and_mutate_row(
+                b"row_key", None, true_case_mutations=None, false_case_mutations=None
+ ) + assert "No mutations provided" in str(e.value) + + @pytest.mark.usefixtures("table") + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_stream(self, table, temp_rows): + """ + Ensure that the read_rows_stream method works + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + + # full table scan + generator = await table.read_rows_stream({}) + first_row = await generator.__anext__() + second_row = await generator.__anext__() + assert first_row.row_key == b"row_key_1" + assert second_row.row_key == b"row_key_2" + with pytest.raises(CrossSync.StopIteration): + await generator.__anext__() + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows(self, table, temp_rows): + """ + Ensure that the read_rows method works + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + # full table scan + row_list = await table.read_rows({}) + assert len(row_list) == 2 + assert row_list[0].row_key == b"row_key_1" + assert row_list[1].row_key == b"row_key_2" + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_sharded_simple(self, table, temp_rows): + """ + Test read rows sharded with two queries + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) + row_list = await table.read_rows_sharded([query1, query2]) + assert len(row_list) == 4 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"b" + assert row_list[3].row_key == b"d" + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_sharded_from_sample(self, table, temp_rows): + """ + Test end-to-end sharding + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.read_rows_query import RowRange + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + + table_shard_keys = await table.sample_row_keys() + query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) + shard_queries = query.shard(table_shard_keys) + row_list = await table.read_rows_sharded(shard_queries) + assert len(row_list) == 3 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"d" + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_sharded_filters_limits(self, table, temp_rows): + """ + Test read rows sharded with filters and limits + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await 
temp_rows.add_row(b"d") + + label_filter1 = ApplyLabelFilter("first") + label_filter2 = ApplyLabelFilter("second") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) + row_list = await table.read_rows_sharded([query1, query2]) + assert len(row_list) == 3 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"b" + assert row_list[2].row_key == b"d" + assert row_list[0][0].labels == ["first"] + assert row_list[1][0].labels == ["second"] + assert row_list[2][0].labels == ["second"] + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_range_query(self, table, temp_rows): + """ + Ensure that the read_rows method works + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data import RowRange + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # full table scan + query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) + row_list = await table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_single_key_query(self, table, temp_rows): + """ + Ensure that the read_rows method works with specified query + """ + from google.cloud.bigtable.data import ReadRowsQuery + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # retrieve specific keys + query = ReadRowsQuery(row_keys=[b"a", b"c"]) + row_list = await table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_with_filter(self, table, temp_rows): + """ + ensure filters are applied + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # retrieve keys with filter + expected_label = "test-label" + row_filter = ApplyLabelFilter(expected_label) + query = ReadRowsQuery(row_filter=row_filter) + row_list = await table.read_rows(query) + assert len(row_list) == 4 + for row in row_list: + assert row[0].labels == [expected_label] + + @pytest.mark.usefixtures("table") + @CrossSync.convert(replace_symbols={"__anext__": "__next__", "aclose": "close"}) + @CrossSync.pytest + async def test_read_rows_stream_close(self, table, temp_rows): + """ + Ensure that the read_rows_stream can be closed + """ + from google.cloud.bigtable.data import ReadRowsQuery + + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + # full table scan + query = ReadRowsQuery() + generator = await table.read_rows_stream(query) + # grab first row + first_row = await generator.__anext__() + assert first_row.row_key == b"row_key_1" + # close stream early + await generator.aclose() + with 
pytest.raises(CrossSync.StopIteration): + await generator.__anext__() + + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_read_row(self, table, temp_rows): + """ + Test read_row (single row helper) + """ + from google.cloud.bigtable.data import Row + + await temp_rows.add_row(b"row_key_1", value=b"value") + row = await table.read_row(b"row_key_1") + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_read_row_missing(self, table): + """ + Test read_row when row does not exist + """ + from google.api_core import exceptions + + row_key = "row_key_not_exist" + result = await table.read_row(row_key) + assert result is None + with pytest.raises(exceptions.InvalidArgument) as e: + await table.read_row("") + assert "Row keys must be non-empty" in str(e) + + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_read_row_w_filter(self, table, temp_rows): + """ + Test read_row (single row helper) + """ + from google.cloud.bigtable.data import Row + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"row_key_1", value=b"value") + expected_label = "test-label" + label_filter = ApplyLabelFilter(expected_label) + row = await table.read_row(b"row_key_1", row_filter=label_filter) + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + assert row.cells[0].labels == [expected_label] + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_row_exists(self, table, temp_rows): + from google.api_core import exceptions + + """Test row_exists with rows that exist and don't exist""" + assert await table.row_exists(b"row_key_1") is False + await temp_rows.add_row(b"row_key_1") + assert await table.row_exists(b"row_key_1") is True + assert await table.row_exists("row_key_1") is True + assert await table.row_exists(b"row_key_2") is False + assert await table.row_exists("row_key_2") is False + assert await table.row_exists("3") is False + await temp_rows.add_row(b"3") + assert await table.row_exists(b"3") is True + with pytest.raises(exceptions.InvalidArgument) as e: + await table.row_exists("") + assert "Row keys must be non-empty" in str(e) + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @pytest.mark.parametrize( + "cell_value,filter_input,expect_match", + [ + (b"abc", b"abc", True), + (b"abc", "abc", True), + (b".", ".", True), + (".*", ".*", True), + (".*", b".*", True), + ("a", ".*", False), + (b".*", b".*", True), + (r"\a", r"\a", True), + (b"\xe2\x98\x83", "☃", True), + ("☃", "☃", True), + (r"\C☃", r"\C☃", True), + (1, 1, True), + (2, 1, False), + (68, 68, True), + ("D", 68, False), + (68, "D", False), + (-1, -1, True), + (2852126720, 2852126720, True), + (-1431655766, -1431655766, True), + (-1431655766, -1, False), + ], + ) + @CrossSync.pytest + async def test_literal_value_filter( + self, table, temp_rows, cell_value, filter_input, expect_match + ): + """ + Literal value filter does complex escaping on re2 strings. 
+ Make sure inputs are properly interpreted by the server + """ + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable.data import ReadRowsQuery + + f = LiteralValueFilter(filter_input) + await temp_rows.add_row(b"row_key_1", value=cell_value) + query = ReadRowsQuery(row_filter=f) + row_list = await table.read_rows(query) + assert len(row_list) == bool( + expect_match + ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 73da1b46d..621f4d9a2 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -16,42 +16,42 @@ from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 -import google.api_core.exceptions as core_exceptions +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import Forbidden + +from google.cloud.bigtable.data._cross_sync import CrossSync # try/except added for compatibility with python < 3.8 try: from unittest import mock - from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER import mock # type: ignore - from mock import AsyncMock # type: ignore - -def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test__mutate_rows" -class TestMutateRowsOperation: +@CrossSync.convert_class("TestMutateRowsOperation") +class TestMutateRowsOperationAsync: def _target_class(self): - from google.cloud.bigtable.data._async._mutate_rows import ( - _MutateRowsOperationAsync, - ) - - return _MutateRowsOperationAsync + return CrossSync._MutateRowsOperation def _make_one(self, *args, **kwargs): if not args: kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) - kwargs["table"] = kwargs.pop("table", AsyncMock()) + kwargs["table"] = kwargs.pop("table", CrossSync.Mock()) kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) return self._target_class()(*args, **kwargs) + def _make_mutation(self, count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + @CrossSync.convert async def _mock_stream(self, mutation_list, error_dict): for idx, entry in enumerate(mutation_list): code = error_dict.get(idx, 0) @@ -64,7 +64,7 @@ async def _mock_stream(self, mutation_list, error_dict): ) def _make_mock_gapic(self, mutation_list, error_dict=None): - mock_fn = AsyncMock() + mock_fn = CrossSync.Mock() if error_dict is None: error_dict = {} mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( @@ -83,7 +83,7 @@ def test_ctor(self): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 attempt_timeout = 0.01 retryable_exceptions = () @@ -131,17 +131,14 @@ def test_ctor_too_many_entries(self): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + entries = [self._make_mutation()] * 
(_MUTATE_ROWS_REQUEST_MUTATION_LIMIT + 1) operation_timeout = 0.05 attempt_timeout = 0.01 - # no errors if at limit - self._make_one(client, table, entries, operation_timeout, attempt_timeout) - # raise error after crossing with pytest.raises(ValueError) as e: self._make_one( client, table, - entries + [_make_mutation()], + entries, operation_timeout, attempt_timeout, ) @@ -150,18 +147,18 @@ def test_ctor_too_many_entries(self): ) assert "Found 100001" in str(e.value) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_rows_operation(self): """ Test successful case of mutate_rows_operation """ client = mock.Mock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 cls = self._target_class() with mock.patch( - f"{cls.__module__}.{cls.__name__}._run_attempt", AsyncMock() + f"{cls.__module__}.{cls.__name__}._run_attempt", CrossSync.Mock() ) as attempt_mock: instance = self._make_one( client, table, entries, operation_timeout, operation_timeout @@ -169,17 +166,15 @@ async def test_mutate_rows_operation(self): await instance.start() assert attempt_mock.call_count == 1 - @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] - ) - @pytest.mark.asyncio + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) + @CrossSync.pytest async def test_mutate_rows_attempt_exception(self, exc_type): """ exceptions raised from attempt should be raised in MutationsExceptionGroup """ - client = AsyncMock() + client = CrossSync.Mock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 expected_exception = exc_type("test") client.mutate_rows.side_effect = expected_exception @@ -197,10 +192,8 @@ async def test_mutate_rows_attempt_exception(self, exc_type): assert len(instance.errors) == 2 assert len(instance.remaining_indices) == 0 - @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] - ) - @pytest.mark.asyncio + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) + @CrossSync.pytest async def test_mutate_rows_exception(self, exc_type): """ exceptions raised from retryable should be raised in MutationsExceptionGroup @@ -210,13 +203,13 @@ async def test_mutate_rows_exception(self, exc_type): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 expected_cause = exc_type("abort") with mock.patch.object( self._target_class(), "_run_attempt", - AsyncMock(), + CrossSync.Mock(), ) as attempt_mock: attempt_mock.side_effect = expected_cause found_exc = None @@ -236,27 +229,24 @@ async def test_mutate_rows_exception(self, exc_type): @pytest.mark.parametrize( "exc_type", - [core_exceptions.DeadlineExceeded, RuntimeError], + [DeadlineExceeded, RuntimeError], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): """ If an exception fails but eventually passes, it should not raise an exception """ - from google.cloud.bigtable.data._async._mutate_rows import ( - _MutateRowsOperationAsync, - ) client = mock.Mock() table = mock.Mock() - entries = [_make_mutation()] + entries = [self._make_mutation()] operation_timeout = 1 expected_cause = exc_type("retry") num_retries = 2 with 
mock.patch.object( - _MutateRowsOperationAsync, + self._target_class(), "_run_attempt", - AsyncMock(), + CrossSync.Mock(), ) as attempt_mock: attempt_mock.side_effect = [expected_cause] * num_retries + [None] instance = self._make_one( @@ -270,7 +260,7 @@ async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): await instance.start() assert attempt_mock.call_count == num_retries + 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_rows_incomplete_ignored(self): """ MutateRowsIncomplete exceptions should not be added to error list @@ -281,12 +271,12 @@ async def test_mutate_rows_incomplete_ignored(self): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation()] + entries = [self._make_mutation()] operation_timeout = 0.05 with mock.patch.object( self._target_class(), "_run_attempt", - AsyncMock(), + CrossSync.Mock(), ) as attempt_mock: attempt_mock.side_effect = _MutateRowsIncomplete("ignored") found_exc = None @@ -301,10 +291,10 @@ async def test_mutate_rows_incomplete_ignored(self): assert len(found_exc.exceptions) == 1 assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_single_entry_success(self): """Test mutating a single entry""" - mutation = _make_mutation() + mutation = self._make_mutation() expected_timeout = 1.3 mock_gapic_fn = self._make_mock_gapic({0: mutation}) instance = self._make_one( @@ -319,7 +309,7 @@ async def test_run_attempt_single_entry_success(self): assert kwargs["timeout"] == expected_timeout assert kwargs["entries"] == [mutation._to_pb()] - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_empty_request(self): """Calling with no mutations should result in no API calls""" mock_gapic_fn = self._make_mock_gapic([]) @@ -329,14 +319,14 @@ async def test_run_attempt_empty_request(self): await instance._run_attempt() assert mock_gapic_fn.call_count == 0 - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_partial_success_retryable(self): """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception""" from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - success_mutation = _make_mutation() - success_mutation_2 = _make_mutation() - failure_mutation = _make_mutation() + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() mutations = [success_mutation, failure_mutation, success_mutation_2] mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) instance = self._make_one( @@ -352,12 +342,12 @@ async def test_run_attempt_partial_success_retryable(self): assert instance.errors[1][0].grpc_status_code == 300 assert 2 not in instance.errors - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_partial_success_non_retryable(self): """Some entries succeed, but one fails. Exception marked as non-retryable. 
Do not raise incomplete error""" - success_mutation = _make_mutation() - success_mutation_2 = _make_mutation() - failure_mutation = _make_mutation() + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() mutations = [success_mutation, failure_mutation, success_mutation_2] mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) instance = self._make_one( diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index e2b02517f..6a4583a7b 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -13,23 +13,22 @@ import pytest -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +from google.cloud.bigtable.data._cross_sync import CrossSync # try/except added for compatibility with python < 3.8 try: from unittest import mock - from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER import mock # type: ignore - from mock import AsyncMock # type: ignore # noqa F401 -TEST_FAMILY = "family_name" -TEST_QUALIFIER = b"qualifier" -TEST_TIMESTAMP = 123456789 -TEST_LABELS = ["label1", "label2"] +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test__read_rows" -class TestReadRowsOperation: + +@CrossSync.convert_class( + sync_name="TestReadRowsOperation", +) +class TestReadRowsOperationAsync: """ Tests helper functions in the ReadRowsOperation class in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt @@ -37,10 +36,9 @@ class TestReadRowsOperation: """ @staticmethod + @CrossSync.convert def _get_target_class(): - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - - return _ReadRowsOperationAsync + return CrossSync._ReadRowsOperation def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -60,8 +58,9 @@ def test_ctor(self): expected_operation_timeout = 42 expected_request_timeout = 44 time_gen_mock = mock.Mock() + subpath = "_async" if CrossSync.is_async else "_sync_autogen" with mock.patch( - "google.cloud.bigtable.data._async._read_rows._attempt_timeout_generator", + f"google.cloud.bigtable.data.{subpath}._read_rows._attempt_timeout_generator", time_gen_mock, ): instance = self._make_one( @@ -236,7 +235,7 @@ def test_revise_to_empty_rowset(self): (4, 2, 2), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_revise_limit(self, start_limit, emit_num, expected_limit): """ revise_limit should revise the request's limit field @@ -277,7 +276,7 @@ async def mock_stream(): assert instance._remaining_count == expected_limit @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_revise_limit_over_limit(self, start_limit, emit_num): """ Should raise runtime error if we get in state where emit_num > start_num @@ -316,7 +315,11 @@ async def mock_stream(): pass assert "emit count exceeds row limit" in str(e.value) - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.convert( + sync_name="test_close", + replace_symbols={"aclose": "close", "__anext__": "__next__"}, + ) async def test_aclose(self): """ should be able to close a stream safely with aclose. 
@@ -328,7 +331,7 @@ async def mock_stream(): yield 1 with mock.patch.object( - _ReadRowsOperationAsync, "_read_rows_attempt" + self._get_target_class(), "_read_rows_attempt" ) as mock_attempt: instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) wrapped_gen = mock_stream() @@ -337,20 +340,20 @@ async def mock_stream(): # read one row await gen.__anext__() await gen.aclose() - with pytest.raises(StopAsyncIteration): + with pytest.raises(CrossSync.StopIteration): await gen.__anext__() # try calling a second time await gen.aclose() # ensure close was propagated to wrapped generator - with pytest.raises(StopAsyncIteration): + with pytest.raises(CrossSync.StopIteration): await wrapped_gen.__anext__() - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) async def test_retryable_ignore_repeated_rows(self): """ Duplicate rows should cause an invalid chunk error """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable_v2.types import ReadRowsResponse @@ -375,37 +378,10 @@ async def mock_stream(): instance = mock.Mock() instance._last_yielded_row_key = None instance._remaining_count = None - stream = _ReadRowsOperationAsync.chunk_stream(instance, mock_awaitable_stream()) + stream = self._get_target_class().chunk_stream( + instance, mock_awaitable_stream() + ) await stream.__anext__() with pytest.raises(InvalidChunk) as exc: await stream.__anext__() assert "row keys should be strictly increasing" in str(exc.value) - - -class MockStream(_ReadRowsOperationAsync): - """ - Mock a _ReadRowsOperationAsync stream for testing - """ - - def __init__(self, items=None, errors=None, operation_timeout=None): - self.transient_errors = errors - self.operation_timeout = operation_timeout - self.next_idx = 0 - if items is None: - items = list(range(10)) - self.items = items - - def __aiter__(self): - return self - - async def __anext__(self): - if self.next_idx >= len(self.items): - raise StopAsyncIteration - item = self.items[self.next_idx] - self.next_idx += 1 - if isinstance(item, Exception): - raise item - return item - - async def aclose(self): - pass diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index fdc86e924..c24fa3d98 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -19,6 +19,7 @@ import sys import pytest +import mock from google.cloud.bigtable.data import mutations from google.auth.credentials import AnonymousCredentials @@ -31,67 +32,71 @@ from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule +from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse -# try/except added for compatibility with python < 3.8 -try: - from unittest import mock - from unittest.mock import AsyncMock # type: ignore -except ImportError: # pragma: NO COVER - import mock # type: ignore - from mock import AsyncMock # type: ignore +from google.cloud.bigtable.data._cross_sync import CrossSync -VENEER_HEADER_REGEX = re.compile( - r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-data-async gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+" -) +if CrossSync.is_async: + from google.api_core import grpc_helpers_async + from google.cloud.bigtable.data._async.client import TableAsync + CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) 
+else: + from google.api_core import grpc_helpers + from google.cloud.bigtable.data._sync_autogen.client import Table # noqa: F401 -def _make_client(*args, use_emulator=True, **kwargs): - import os - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync + CrossSync.add_mapping("grpc_helpers", grpc_helpers) - env_mask = {} - # by default, use emulator mode to avoid auth issues in CI - # emulator mode must be disabled by tests that check refresh background tasks - if use_emulator: - env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" - else: - # set some default values - kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) - kwargs["project"] = kwargs.get("project", "project-id") - with mock.patch.dict(os.environ, env_mask): - return BigtableDataClientAsync(*args, **kwargs) +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_client" +@CrossSync.convert_class( + sync_name="TestBigtableDataClient", + add_mapping_for_name="TestBigtableDataClient", +) class TestBigtableDataClientAsync: - def _get_target_class(self): - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync - - return BigtableDataClientAsync - - def _make_one(self, *args, **kwargs): - return _make_client(*args, **kwargs) + @staticmethod + @CrossSync.convert + def _get_target_class(): + return CrossSync.DataClient + + @classmethod + def _make_client(cls, *args, use_emulator=True, **kwargs): + import os + + env_mask = {} + # by default, use emulator mode to avoid auth issues in CI + # emulator mode must be disabled by tests that check channel pooling/refresh background tasks + if use_emulator: + env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" + import warnings + + warnings.filterwarnings("ignore", category=RuntimeWarning) + else: + # set some default values + kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) + kwargs["project"] = kwargs.get("project", "project-id") + with mock.patch.dict(os.environ, env_mask): + return cls._get_target_class()(*args, **kwargs) - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor(self): expected_project = "project-id" expected_credentials = AnonymousCredentials() - client = self._make_one( + client = self._make_client( project="project-id", credentials=expected_credentials, use_emulator=False, ) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert client.project == expected_project assert not client._active_instances assert client._channel_refresh_task is not None assert client.transport._credentials == expected_credentials await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_super_inits(self): - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) from google.cloud.client import ClientWithProject from google.api_core import client_options as client_options_lib @@ -99,14 +104,16 @@ async def test_ctor_super_inits(self): credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) - with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + with mock.patch.object( + CrossSync.GapicClient, "__init__" + ) as bigtable_client_init: bigtable_client_init.return_value = None with mock.patch.object( ClientWithProject, "__init__" ) as client_project_init: client_project_init.return_value = None try: - self._make_one( + self._make_client( project=project, credentials=credentials, 
client_options=options_parsed, @@ -126,17 +133,16 @@ async def test_ctor_super_inits(self): assert kwargs["credentials"] == credentials assert kwargs["client_options"] == options_parsed - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_dict_options(self): - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) from google.api_core.client_options import ClientOptions client_options = {"api_endpoint": "foo.bar:1234"} - with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + with mock.patch.object( + CrossSync.GapicClient, "__init__" + ) as bigtable_client_init: try: - self._make_one(client_options=client_options) + self._make_client(client_options=client_options) except TypeError: pass bigtable_client_init.assert_called_once() @@ -147,17 +153,29 @@ async def test_ctor_dict_options(self): with mock.patch.object( self._get_target_class(), "_start_background_channel_refresh" ) as start_background_refresh: - client = self._make_one(client_options=client_options, use_emulator=False) + client = self._make_client( + client_options=client_options, use_emulator=False + ) start_background_refresh.assert_called_once() await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_veneer_grpc_headers(self): + client_component = "data-async" if CrossSync.is_async else "data" + VENEER_HEADER_REGEX = re.compile( + r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-" + + client_component + + r" gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+" + ) + # client_info should be populated with headers to # detect as a veneer client - patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") + if CrossSync.is_async: + patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") + else: + patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") with patch as gapic_mock: - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") wrapped_call_list = gapic_mock.call_args_list assert len(wrapped_call_list) > 0 # each wrapped call should have veneer headers @@ -172,56 +190,67 @@ async def test_veneer_grpc_headers(self): ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" await client.close() + @CrossSync.drop @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test__start_background_channel_refresh_sync(self): # should raise RuntimeError if called in a sync context - client = self._make_one(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) with pytest.raises(RuntimeError): client._start_background_channel_refresh() - @pytest.mark.asyncio + @CrossSync.pytest async def test__start_background_channel_refresh_task_exists(self): # if tasks exist, should do nothing - client = self._make_one(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) assert client._channel_refresh_task is not None with mock.patch.object(asyncio, "create_task") as create_task: client._start_background_channel_refresh() create_task.assert_not_called() await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__start_background_channel_refresh(self): # should create background tasks for each channel - client = self._make_one(project="project-id", use_emulator=False) - ping_and_warm = AsyncMock() - client._ping_and_warm_instances = ping_and_warm - client._start_background_channel_refresh() - assert 
client._channel_refresh_task is not None - assert isinstance(client._channel_refresh_task, asyncio.Task) - await asyncio.sleep(0.1) - assert ping_and_warm.call_count == 1 - await client.close() + client = self._make_client(project="project-id") + with mock.patch.object( + client, "_ping_and_warm_instances", CrossSync.Mock() + ) as ping_and_warm: + client._emulator_host = None + client._start_background_channel_refresh() + assert client._channel_refresh_task is not None + assert isinstance(client._channel_refresh_task, CrossSync.Task) + await CrossSync.sleep(0.1) + assert ping_and_warm.call_count == 1 + await client.close() - @pytest.mark.asyncio + @CrossSync.drop + @CrossSync.pytest @pytest.mark.skipif( sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" ) async def test__start_background_channel_refresh_task_names(self): # if tasks exist, should do nothing - client = self._make_one(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) name = client._channel_refresh_task.get_name() - assert "BigtableDataClientAsync channel refresh" in name + assert "channel refresh" in name await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__ping_and_warm_instances(self): """ test ping and warm with mocked asyncio.gather """ client_mock = mock.Mock() - with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: - # simulate gather by returning the same number of items as passed in - gather.side_effect = lambda *args, **kwargs: [None for _ in args] + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + with mock.patch.object( + CrossSync, "gather_partials", CrossSync.Mock() + ) as gather: + # gather_partials is expected to call the function passed, and return the result + gather.side_effect = lambda partials, **kwargs: [None for _ in partials] channel = mock.Mock() # test with no instances client_mock._active_instances = [] @@ -229,10 +258,8 @@ async def test__ping_and_warm_instances(self): client_mock, channel=channel ) assert len(result) == 0 - gather.assert_called_once() - gather.assert_awaited_once() - assert not gather.call_args.args - assert gather.call_args.kwargs == {"return_exceptions": True} + assert gather.call_args[1]["return_exceptions"] is True + assert gather.call_args[1]["sync_executor"] == client_mock._executor # test with instances client_mock._active_instances = [ (mock.Mock(), mock.Mock(), mock.Mock()) @@ -244,8 +271,11 @@ async def test__ping_and_warm_instances(self): ) assert len(result) == 4 gather.assert_called_once() - gather.assert_awaited_once() - assert len(gather.call_args.args) == 4 + # expect one partial for each instance + partial_list = gather.call_args.args[0] + assert len(partial_list) == 4 + if CrossSync.is_async: + gather.assert_awaited_once() # check grpc call arguments grpc_call_args = channel.unary_unary().call_args_list for idx, (_, kwargs) in enumerate(grpc_call_args): @@ -265,15 +295,21 @@ async def test__ping_and_warm_instances(self): == f"name={expected_instance}&app_profile_id={expected_app_profile}" ) - @pytest.mark.asyncio + @CrossSync.pytest async def test__ping_and_warm_single_instance(self): """ should be able to call ping and warm with single instance """ client_mock = mock.Mock() - with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: - # simulate gather by returning the same number of items as passed in - gather.side_effect = lambda *args, 
**kwargs: [None for _ in args] + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + with mock.patch.object( + CrossSync, "gather_partials", CrossSync.Mock() + ) as gather: + gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] # test with large set of instances client_mock._active_instances = [mock.Mock()] * 100 test_key = ("test-instance", "test-table", "test-app-profile") @@ -298,7 +334,7 @@ async def test__ping_and_warm_single_instance(self): metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" ) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "refresh_interval, wait_time, expected_sleep", [ @@ -316,38 +352,43 @@ async def test__manage_channel_first_sleep( # first sleep time should be `refresh_interval` seconds after client init import time - with mock.patch.object(time, "monotonic") as time: - time.return_value = 0 - with mock.patch.object(asyncio, "sleep") as sleep: + with mock.patch.object(time, "monotonic") as monotonic: + monotonic.return_value = 0 + with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = asyncio.CancelledError try: - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") client._channel_init_time = -wait_time await client._manage_channel(refresh_interval, refresh_interval) except asyncio.CancelledError: pass sleep.assert_called_once() - call_time = sleep.call_args[0][0] + call_time = sleep.call_args[0][1] assert ( abs(call_time - expected_sleep) < 0.1 ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__manage_channel_ping_and_warm(self): """ _manage channel should call ping and warm internally """ import time + import threading client_mock = mock.Mock() + client_mock._is_closed.is_set.return_value = False client_mock._channel_init_time = time.monotonic() orig_channel = client_mock.transport.grpc_channel # should ping an warm all new channels, and old channels if sleeping - with mock.patch.object(asyncio, "sleep"): + sleep_tuple = ( + (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple): # stop process after close is called orig_channel.close.side_effect = asyncio.CancelledError - ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() + ping_and_warm = client_mock._ping_and_warm_instances = CrossSync.Mock() # should ping and warm old channel then new if sleep > 0 try: await self._get_target_class()._manage_channel(client_mock, 10) @@ -362,7 +403,7 @@ async def test__manage_channel_ping_and_warm(self): assert orig_channel in called_with assert client_mock.transport.grpc_channel in called_with - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "refresh_interval, num_cycles, expected_sleep", [ @@ -379,46 +420,46 @@ async def test__manage_channel_sleeps( import random channel = mock.Mock() - channel.close = mock.AsyncMock() + channel.close = CrossSync.Mock() with mock.patch.object(random, "uniform") as uniform: uniform.side_effect = lambda min_, max_: min_ - with mock.patch.object(time, "time") as time: - time.return_value = 0 - with mock.patch.object(asyncio, "sleep") as sleep: + with mock.patch.object(time, "time") as time_mock: + time_mock.return_value = 0 + with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = [None for i 
in range(num_cycles - 1)] + [ asyncio.CancelledError ] - try: - client = self._make_one(project="project-id") - client.transport._grpc_channel = channel - with mock.patch.object( - client.transport, "create_channel", return_value=channel - ): + client = self._make_client(project="project-id") + client.transport._grpc_channel = channel + with mock.patch.object( + client.transport, "create_channel", CrossSync.Mock + ): + try: if refresh_interval is not None: await client._manage_channel( - refresh_interval, refresh_interval + refresh_interval, refresh_interval, grace_period=0 ) else: - await client._manage_channel() - except asyncio.CancelledError: - pass + await client._manage_channel(grace_period=0) + except asyncio.CancelledError: + pass assert sleep.call_count == num_cycles - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + total_sleep = sum([call[0][1] for call in sleep.call_args_list]) assert ( abs(total_sleep - expected_sleep) < 0.1 ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__manage_channel_random(self): import random - with mock.patch.object(asyncio, "sleep") as sleep: + with mock.patch.object(CrossSync, "event_wait") as sleep: with mock.patch.object(random, "uniform") as uniform: uniform.return_value = 0 try: uniform.side_effect = asyncio.CancelledError - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") except asyncio.CancelledError: uniform.side_effect = None uniform.reset_mock() @@ -429,7 +470,7 @@ async def test__manage_channel_random(self): uniform.side_effect = lambda min_, max_: min_ sleep.side_effect = [None, asyncio.CancelledError] try: - await client._manage_channel(min_val, max_val) + await client._manage_channel(min_val, max_val, grace_period=0) except asyncio.CancelledError: pass assert uniform.call_count == 2 @@ -438,39 +479,35 @@ async def test__manage_channel_random(self): assert found_min == min_val assert found_max == max_val - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) async def test__manage_channel_refresh(self, num_cycles): # make sure that channels are properly refreshed - from google.api_core import grpc_helpers_async - - expected_grace = 9 expected_refresh = 0.5 - new_channel = grpc.aio.insecure_channel("localhost:8080") + grpc_lib = grpc.aio if CrossSync.is_async else grpc + new_channel = grpc_lib.insecure_channel("localhost:8080") - with mock.patch.object(asyncio, "sleep") as sleep: - sleep.side_effect = [None for i in range(num_cycles)] + [ - asyncio.CancelledError - ] + with mock.patch.object(CrossSync, "event_wait") as sleep: + sleep.side_effect = [None for i in range(num_cycles)] + [RuntimeError] with mock.patch.object( - grpc_helpers_async, "create_channel" + CrossSync.grpc_helpers, "create_channel" ) as create_channel: create_channel.return_value = new_channel - client = self._make_one(project="project-id", use_emulator=False) + client = self._make_client(project="project-id") create_channel.reset_mock() try: await client._manage_channel( refresh_interval_min=expected_refresh, refresh_interval_max=expected_refresh, - grace_period=expected_grace, + grace_period=0, ) - except asyncio.CancelledError: + except RuntimeError: pass assert sleep.call_count == num_cycles + 1 assert create_channel.call_count == num_cycles await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def 
test__register_instance(self): """ test instance registration @@ -483,7 +520,7 @@ async def test__register_instance(self): client_mock._active_instances = active_instances client_mock._instance_owners = instance_owners client_mock._channel_refresh_task = None - client_mock._ping_and_warm_instances = AsyncMock() + client_mock._ping_and_warm_instances = CrossSync.Mock() table_mock = mock.Mock() await self._get_target_class()._register_instance( client_mock, "instance-1", table_mock @@ -535,7 +572,7 @@ async def test__register_instance(self): ] ) - @pytest.mark.asyncio + @CrossSync.pytest async def test__register_instance_duplicate(self): """ test double instance registration. Should be no-op @@ -547,10 +584,10 @@ async def test__register_instance_duplicate(self): instance_owners = {} client_mock._active_instances = active_instances client_mock._instance_owners = instance_owners - client_mock._channel_refresh_tasks = [object()] + client_mock._channel_refresh_task = object() mock_channels = [mock.Mock()] client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = AsyncMock() + client_mock._ping_and_warm_instances = CrossSync.Mock() table_mock = mock.Mock() expected_key = ( "prefix/instance-1", @@ -577,7 +614,7 @@ async def test__register_instance_duplicate(self): assert expected_key == tuple(list(instance_owners)[0]) assert client_mock._ping_and_warm_instances.call_count == 1 - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "insert_instances,expected_active,expected_owner_keys", [ @@ -604,13 +641,8 @@ async def test__register_instance_state( instance_owners = {} client_mock._active_instances = active_instances client_mock._instance_owners = instance_owners - client_mock._channel_refresh_tasks = [] - client_mock._start_background_channel_refresh.side_effect = ( - lambda: client_mock._channel_refresh_tasks.append(mock.Mock) - ) - mock_channels = [mock.Mock() for i in range(5)] - client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = AsyncMock() + client_mock._channel_refresh_task = None + client_mock._ping_and_warm_instances = CrossSync.Mock() table_mock = mock.Mock() # register instances for instance, table, profile in insert_instances: @@ -636,9 +668,9 @@ async def test__register_instance_state( ] ) - @pytest.mark.asyncio + @CrossSync.pytest async def test__remove_instance_registration(self): - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") table = mock.Mock() await client._register_instance("instance-1", table) await client._register_instance("instance-2", table) @@ -667,16 +699,16 @@ async def test__remove_instance_registration(self): assert len(client._active_instances) == 1 await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__multiple_table_registration(self): """ registering with multiple tables with the same key should add multiple owners to instance_owners, but only keep one copy of shared key in active_instances """ - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - async with self._make_one(project="project-id") as client: + async with self._make_client(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" @@ -689,12 +721,20 @@ async def test__multiple_table_registration(self): assert id(table_1) in 
client._instance_owners[instance_1_key] # duplicate table should register in instance_owners under same key async with client.get_table("instance_1", "table_1") as table_2: + assert table_2._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_2._register_instance_future.result() assert len(client._instance_owners[instance_1_key]) == 2 assert len(client._active_instances) == 1 assert id(table_1) in client._instance_owners[instance_1_key] assert id(table_2) in client._instance_owners[instance_1_key] # unique table should register in instance_owners and active_instances async with client.get_table("instance_1", "table_3") as table_3: + assert table_3._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_3._register_instance_future.result() instance_3_path = client._gapic_client.instance_path( client.project, "instance_1" ) @@ -716,17 +756,25 @@ async def test__multiple_table_registration(self): assert instance_1_key not in client._active_instances assert len(client._instance_owners[instance_1_key]) == 0 - @pytest.mark.asyncio + @CrossSync.pytest async def test__multiple_instance_registration(self): """ registering with multiple instance keys should update the key in instance_owners and active_instances """ - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - async with self._make_one(project="project-id") as client: + async with self._make_client(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: + assert table_1._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_1._register_instance_future.result() async with client.get_table("instance_2", "table_2") as table_2: + assert table_2._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_2._register_instance_future.result() instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" ) @@ -755,12 +803,11 @@ async def test__multiple_instance_registration(self): assert len(client._instance_owners[instance_1_key]) == 0 assert len(client._instance_owners[instance_2_key]) == 0 - @pytest.mark.asyncio + @CrossSync.pytest async def test_get_table(self): - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") assert not client._active_instances expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -770,8 +817,8 @@ async def test_get_table(self): expected_table_id, expected_app_profile_id, ) - await asyncio.sleep(0) - assert isinstance(table, TableAsync) + await CrossSync.yield_to_event_loop() + assert isinstance(table, CrossSync.TestTable._get_target_class()) assert table.table_id == expected_table_id assert ( table.table_name @@ -791,14 +838,14 @@ async def test_get_table(self): assert client._instance_owners[instance_key] == {id(table)} await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_get_table_arg_passthrough(self): """ All arguments passed in get_table should be sent to constructor """ - async with self._make_one(project="project-id") as client: - with 
mock.patch( - "google.cloud.bigtable.data._async.client.TableAsync.__init__", + async with self._make_client(project="project-id") as client: + with mock.patch.object( + CrossSync.TestTable._get_target_class(), "__init__" ) as mock_constructor: mock_constructor.return_value = None assert not client._active_instances @@ -824,25 +871,26 @@ async def test_get_table_arg_passthrough(self): **expected_kwargs, ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_get_table_context_manager(self): - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey expected_table_id = "table-id" expected_instance_id = "instance-id" expected_app_profile_id = "app-profile-id" expected_project_id = "project-id" - with mock.patch.object(TableAsync, "close") as close_mock: - async with self._make_one(project=expected_project_id) as client: + with mock.patch.object( + CrossSync.TestTable._get_target_class(), "close" + ) as close_mock: + async with self._make_client(project=expected_project_id) as client: async with client.get_table( expected_instance_id, expected_table_id, expected_app_profile_id, ) as table: - await asyncio.sleep(0) - assert isinstance(table, TableAsync) + await CrossSync.yield_to_event_loop() + assert isinstance(table, CrossSync.TestTable._get_target_class()) assert table.table_id == expected_table_id assert ( table.table_name @@ -862,53 +910,63 @@ async def test_get_table_context_manager(self): assert client._instance_owners[instance_key] == {id(table)} assert close_mock.call_count == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_close(self): - client = self._make_one(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) task = client._channel_refresh_task assert task is not None assert not task.done() - with mock.patch.object(client.transport, "close", AsyncMock()) as close_mock: + with mock.patch.object( + client.transport, "close", CrossSync.Mock() + ) as close_mock: await client.close() close_mock.assert_called_once() - close_mock.assert_awaited() + if CrossSync.is_async: + close_mock.assert_awaited() assert task.done() - assert task.cancelled() assert client._channel_refresh_task is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_close_with_timeout(self): expected_timeout = 19 - client = self._make_one(project="project-id", use_emulator=False) - with mock.patch.object(asyncio, "wait_for", AsyncMock()) as wait_for_mock: + client = self._make_client(project="project-id", use_emulator=False) + with mock.patch.object(CrossSync, "wait", CrossSync.Mock()) as wait_for_mock: await client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() - wait_for_mock.assert_awaited() + if CrossSync.is_async: + wait_for_mock.assert_awaited() assert wait_for_mock.call_args[1]["timeout"] == expected_timeout await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_context_manager(self): + from functools import partial + # context manager should close the client cleanly - close_mock = AsyncMock() + close_mock = CrossSync.Mock() true_close = None - async with self._make_one(project="project-id", use_emulator=False) as client: - true_close = client.close() + async with self._make_client( + project="project-id", use_emulator=False + ) as client: + # grab reference to close coro for async test + true_close = partial(client.close) client.close 
= close_mock assert not client._channel_refresh_task.done() assert client.project == "project-id" assert client._active_instances == set() close_mock.assert_not_called() close_mock.assert_called_once() - close_mock.assert_awaited() + if CrossSync.is_async: + close_mock.assert_awaited() # actually close the client - await true_close + await true_close() + @CrossSync.drop def test_client_ctor_sync(self): # initializing client in a sync context should raise RuntimeError with pytest.warns(RuntimeWarning) as warnings: - client = _make_client(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) expected_warning = [w for w in warnings if "client.py" in w.filename] assert len(expected_warning) == 1 assert ( @@ -919,11 +977,20 @@ def test_client_ctor_sync(self): assert client._channel_refresh_task is None +@CrossSync.convert_class("TestTable", add_mapping_for_name="TestTable") class TestTableAsync: - @pytest.mark.asyncio + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @staticmethod + @CrossSync.convert + def _get_target_class(): + return CrossSync.Table + + @CrossSync.pytest async def test_table_ctor(self): - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -934,10 +1001,10 @@ async def test_table_ctor(self): expected_read_rows_attempt_timeout = 0.5 expected_mutate_rows_operation_timeout = 2.5 expected_mutate_rows_attempt_timeout = 0.75 - client = _make_client() + client = self._make_client() assert not client._active_instances - table = TableAsync( + table = self._get_target_class()( client, expected_instance_id, expected_table_id, @@ -949,7 +1016,7 @@ async def test_table_ctor(self): default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, ) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id assert table.app_profile_id == expected_app_profile_id @@ -978,30 +1045,28 @@ async def test_table_ctor(self): == expected_mutate_rows_attempt_timeout ) # ensure task reaches completion - await table._register_instance_task - assert table._register_instance_task.done() - assert not table._register_instance_task.cancelled() - assert table._register_instance_task.exception() is None + await table._register_instance_future + assert table._register_instance_future.done() + assert not table._register_instance_future.cancelled() + assert table._register_instance_future.exception() is None await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_table_ctor_defaults(self): """ should provide default timeout values and app_profile_id """ - from google.cloud.bigtable.data._async.client import TableAsync - expected_table_id = "table-id" expected_instance_id = "instance-id" - client = _make_client() + client = self._make_client() assert not client._active_instances - table = TableAsync( + table = self._get_target_class()( client, expected_instance_id, expected_table_id, ) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id 
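        # CrossSync.yield_to_event_loop() above replaces the old asyncio.sleep(0): it yields
        # control once after the table is constructed (presumably a no-op in the generated sync
        # variant) so any background work, such as instance registration, has a chance to be
        # scheduled before the default attribute values are asserted below.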
assert table.app_profile_id is None @@ -1014,14 +1079,12 @@ async def test_table_ctor_defaults(self): assert table.default_mutate_rows_attempt_timeout == 60 await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_table_ctor_invalid_timeout_values(self): """ bad timeout values should raise ValueError """ - from google.cloud.bigtable.data._async.client import TableAsync - - client = _make_client() + client = self._make_client() timeout_pairs = [ ("default_operation_timeout", "default_attempt_timeout"), @@ -1036,68 +1099,67 @@ async def test_table_ctor_invalid_timeout_values(self): ] for operation_timeout, attempt_timeout in timeout_pairs: with pytest.raises(ValueError) as e: - TableAsync(client, "", "", **{attempt_timeout: -1}) + self._get_target_class()(client, "", "", **{attempt_timeout: -1}) assert "attempt_timeout must be greater than 0" in str(e.value) with pytest.raises(ValueError) as e: - TableAsync(client, "", "", **{operation_timeout: -1}) + self._get_target_class()(client, "", "", **{operation_timeout: -1}) assert "operation_timeout must be greater than 0" in str(e.value) await client.close() + @CrossSync.drop def test_table_ctor_sync(self): # initializing client in a sync context should raise RuntimeError - from google.cloud.bigtable.data._async.client import TableAsync - client = mock.Mock() with pytest.raises(RuntimeError) as e: TableAsync(client, "instance-id", "table-id") assert e.match("TableAsync must be created within an async event loop context.") - @pytest.mark.asyncio + @CrossSync.pytest # iterate over all retryable rpcs @pytest.mark.parametrize( - "fn_name,fn_args,retry_fn_path,extra_retryables", + "fn_name,fn_args,is_stream,extra_retryables", [ ( "read_rows_stream", (ReadRowsQuery(),), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "read_rows", (ReadRowsQuery(),), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "read_row", (b"row_key",), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "read_rows_sharded", ([ReadRowsQuery()],), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "row_exists", (b"row_key",), - "google.api_core.retry.retry_target_stream_async", + True, (), ), - ("sample_row_keys", (), "google.api_core.retry.retry_target_async", ()), + ("sample_row_keys", (), False, ()), ( "mutate_row", (b"row_key", [mock.Mock()]), - "google.api_core.retry.retry_target_async", + False, (), ), ( "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mock.Mock()])],), - "google.api_core.retry.retry_target_async", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + False, (_MutateRowsIncomplete,), ), ], @@ -1132,17 +1194,26 @@ async def test_customizable_retryable_errors( expected_retryables, fn_name, fn_args, - retry_fn_path, + is_stream, extra_retryables, ): """ Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. 
""" - with mock.patch(retry_fn_path) as retry_fn_mock: - async with _make_client() as client: + retry_fn = "retry_target" + if is_stream: + retry_fn += "_stream" + if CrossSync.is_async: + retry_fn = f"CrossSync.{retry_fn}" + else: + retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" + with mock.patch( + f"google.cloud.bigtable.data._cross_sync.{retry_fn}" + ) as retry_fn_mock: + async with self._make_client() as client: table = client.get_table("instance-id", "table-id") - expected_predicate = lambda a: a in expected_retryables # noqa + expected_predicate = expected_retryables.__contains__ retry_fn_mock.side_effect = RuntimeError("stop early") with mock.patch( "google.api_core.retry.if_exception_type" @@ -1184,20 +1255,22 @@ async def test_customizable_retryable_errors( ], ) @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.convert async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): - """check that all requests attach proper metadata headers""" - from google.cloud.bigtable.data import TableAsync - profile = "profile" if include_app_profile else None - client = _make_client() + client = self._make_client() # create mock for rpc stub transport_mock = mock.MagicMock() - rpc_mock = mock.AsyncMock() + rpc_mock = CrossSync.Mock() transport_mock._wrapped_methods.__getitem__.return_value = rpc_mock - client._gapic_client._client._transport = transport_mock - client._gapic_client._client._is_universe_domain_valid = True - table = TableAsync(client, "instance-id", "table-id", profile) + gapic_client = client._gapic_client + if CrossSync.is_async: + # inner BigtableClient is held as ._client for BigtableAsyncClient + gapic_client = gapic_client._client + gapic_client._transport = transport_mock + gapic_client._is_universe_domain_valid = True + table = self._get_target_class()(client, "instance-id", "table-id", profile) try: test_fn = table.__getattribute__(fn_name) maybe_stream = await test_fn(*fn_args) @@ -1220,20 +1293,32 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in routing_str -class TestReadRows: +@CrossSync.convert_class( + "TestReadRows", + add_mapping_for_name="TestReadRows", +) +class TestReadRowsAsync: """ Tests for table.read_rows and related methods. 
""" - def _make_table(self, *args, **kwargs): - from google.cloud.bigtable.data._async.client import TableAsync + @staticmethod + @CrossSync.convert + def _get_operation_class(): + return CrossSync._ReadRowsOperation + + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + @CrossSync.convert + def _make_table(self, *args, **kwargs): client_mock = mock.Mock() client_mock._register_instance.side_effect = ( - lambda *args, **kwargs: asyncio.sleep(0) + lambda *args, **kwargs: CrossSync.yield_to_event_loop() ) client_mock._remove_instance_registration.side_effect = ( - lambda *args, **kwargs: asyncio.sleep(0) + lambda *args, **kwargs: CrossSync.yield_to_event_loop() ) kwargs["instance_id"] = kwargs.get( "instance_id", args[0] if args else "instance" @@ -1243,7 +1328,7 @@ def _make_table(self, *args, **kwargs): ) client_mock._gapic_client.table_path.return_value = kwargs["table_id"] client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] - return TableAsync(client_mock, *args, **kwargs) + return CrossSync.TestTable._get_target_class()(client_mock, *args, **kwargs) def _make_stats(self): from google.cloud.bigtable_v2.types import RequestStats @@ -1274,6 +1359,7 @@ def _make_chunk(*args, **kwargs): return ReadRowsResponse.CellChunk(*args, **kwargs) @staticmethod + @CrossSync.convert async def _make_gapic_stream( chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0, @@ -1286,30 +1372,33 @@ def __init__(self, chunk_list, sleep_time): self.idx = -1 self.sleep_time = sleep_time + @CrossSync.convert(sync_name="__iter__") def __aiter__(self): return self + @CrossSync.convert(sync_name="__next__") async def __anext__(self): self.idx += 1 if len(self.chunk_list) > self.idx: if sleep_time: - await asyncio.sleep(self.sleep_time) + await CrossSync.sleep(self.sleep_time) chunk = self.chunk_list[self.idx] if isinstance(chunk, Exception): raise chunk else: return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration + raise CrossSync.StopIteration def cancel(self): pass return mock_stream(chunk_list, sleep_time) + @CrossSync.convert async def execute_fn(self, table, *args, **kwargs): return await table.read_rows(*args, **kwargs) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows(self): query = ReadRowsQuery() chunks = [ @@ -1326,7 +1415,7 @@ async def test_read_rows(self): assert results[0].row_key == b"test_1" assert results[1].row_key == b"test_2" - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_stream(self): query = ReadRowsQuery() chunks = [ @@ -1345,7 +1434,7 @@ async def test_read_rows_stream(self): assert results[1].row_key == b"test_2" @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_query_matches_request(self, include_app_profile): from google.cloud.bigtable.data import RowRange from google.cloud.bigtable.data.row_filters import PassAllFilter @@ -1372,14 +1461,14 @@ async def test_read_rows_query_matches_request(self, include_app_profile): assert call_request == query_pb @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_timeout(self, operation_timeout): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows query = ReadRowsQuery() chunks = [self._make_chunk(row_key=b"test_1")] read_rows.side_effect = lambda *args, **kwargs: 
self._make_gapic_stream( - chunks, sleep_time=1 + chunks, sleep_time=0.15 ) try: await table.read_rows(query, operation_timeout=operation_timeout) @@ -1397,7 +1486,7 @@ async def test_read_rows_timeout(self, operation_timeout): (0.05, 0.24, 5), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_attempt_timeout( self, per_request_t, operation_t, expected_num ): @@ -1460,7 +1549,7 @@ async def test_read_rows_attempt_timeout( core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_retryable_error(self, exc_type): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows @@ -1491,7 +1580,7 @@ async def test_read_rows_retryable_error(self, exc_type): InvalidChunk, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_non_retryable_error(self, exc_type): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows @@ -1505,18 +1594,17 @@ async def test_read_rows_non_retryable_error(self, exc_type): except exc_type as e: assert e == expected_error - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_revise_request(self): """ Ensure that _revise_request is called between retries """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable_v2.types import RowSet return_val = RowSet() with mock.patch.object( - _ReadRowsOperationAsync, "_revise_request_rowset" + self._get_operation_class(), "_revise_request_rowset" ) as revise_rowset: revise_rowset.return_value = return_val async with self._make_table() as table: @@ -1540,16 +1628,14 @@ async def test_read_rows_revise_request(self): revised_call = read_rows.call_args_list[1].args[0] assert revised_call.rows == return_val - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_default_timeouts(self): """ Ensure that the default timeouts are set on the read rows operation when not overridden """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - operation_timeout = 8 attempt_timeout = 4 - with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op: + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") async with self._make_table( default_read_rows_operation_timeout=operation_timeout, @@ -1563,16 +1649,14 @@ async def test_read_rows_default_timeouts(self): assert kwargs["operation_timeout"] == operation_timeout assert kwargs["attempt_timeout"] == attempt_timeout - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_default_timeout_override(self): """ When timeouts are passed, they overwrite default values """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - operation_timeout = 8 attempt_timeout = 4 - with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op: + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") async with self._make_table( default_operation_timeout=99, default_attempt_timeout=97 @@ -1589,10 +1673,10 @@ async def test_read_rows_default_timeout_override(self): assert kwargs["operation_timeout"] == operation_timeout assert kwargs["attempt_timeout"] == attempt_timeout - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row(self): """Test reading a single row""" - async with _make_client() 
as client: + async with self._make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1617,10 +1701,10 @@ async def test_read_row(self): assert query.row_ranges == [] assert query.limit == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row_w_filter(self): """Test reading a single row with an added filter""" - async with _make_client() as client: + async with self._make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1650,10 +1734,10 @@ async def test_read_row_w_filter(self): assert query.limit == 1 assert query.filter == expected_filter - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row_no_response(self): """should return None if row does not exist""" - async with _make_client() as client: + async with self._make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1685,10 +1769,10 @@ async def test_read_row_no_response(self): ([object(), object()], True), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_row_exists(self, return_value, expected_result): """Test checking for row existence""" - async with _make_client() as client: + async with self._make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1722,32 +1806,35 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter -class TestReadRowsSharded: - @pytest.mark.asyncio +@CrossSync.convert_class("TestReadRowsSharded") +class TestReadRowsShardedAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @CrossSync.pytest async def test_read_rows_sharded_empty_query(self): - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as exc: await table.read_rows_sharded([]) assert "empty sharded_query" in str(exc.value) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_multiple_queries(self): """ Test with multiple queries. 
Should return results from both """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( table.client._gapic_client, "read_rows" ) as read_rows: - read_rows.side_effect = ( - lambda *args, **kwargs: TestReadRows._make_gapic_stream( - [ - TestReadRows._make_chunk(row_key=k) - for k in args[0].rows.row_keys - ] - ) + read_rows.side_effect = lambda *args, **kwargs: CrossSync.TestReadRows._make_gapic_stream( + [ + CrossSync.TestReadRows._make_chunk(row_key=k) + for k in args[0].rows.row_keys + ] ) query_1 = ReadRowsQuery(b"test_1") query_2 = ReadRowsQuery(b"test_2") @@ -1757,19 +1844,19 @@ async def test_read_rows_sharded_multiple_queries(self): assert result[1].row_key == b"test_2" @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_multiple_queries_calls(self, n_queries): """ Each query should trigger a separate read_rows call """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: query_list = [ReadRowsQuery() for _ in range(n_queries)] await table.read_rows_sharded(query_list) assert read_rows.call_count == n_queries - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_errors(self): """ Errors should be exposed as ShardedReadRowsExceptionGroups @@ -1777,7 +1864,7 @@ async def test_read_rows_sharded_errors(self): from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.cloud.bigtable.data.exceptions import FailedQueryShardError - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = RuntimeError("mock error") @@ -1797,7 +1884,7 @@ async def test_read_rows_sharded_errors(self): assert exc.value.exceptions[1].index == 1 assert exc.value.exceptions[1].query == query_2 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_concurrent(self): """ Ensure sharded requests are concurrent @@ -1805,10 +1892,10 @@ async def test_read_rows_sharded_concurrent(self): import time async def mock_call(*args, **kwargs): - await asyncio.sleep(0.1) + await CrossSync.sleep(0.1) return [mock.Mock()] - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -1821,14 +1908,14 @@ async def mock_call(*args, **kwargs): # if run in sequence, we would expect this to take 1 second assert call_time < 0.2 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_concurrency_limit(self): """ Only 10 queries should be processed concurrently. 
Others should be queued Should start a new query as soon as previous finishes """ - from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT + from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT assert _CONCURRENCY_LIMIT == 10 # change this test if this changes num_queries = 15 @@ -1846,7 +1933,7 @@ async def mock_call(*args, **kwargs): starting_timeout = 10 - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -1870,13 +1957,13 @@ async def mock_call(*args, **kwargs): idx = i + _CONCURRENCY_LIMIT assert rpc_start_list[idx] - (i * increment_time) < eps - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_expirary(self): """ If the operation times out before all shards complete, should raise a ShardedReadRowsExceptionGroup """ - from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT + from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.api_core.exceptions import DeadlineExceeded @@ -1896,7 +1983,7 @@ async def mock_call(*args, **kwargs): await asyncio.sleep(next_item) return [mock.Mock()] - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -1910,7 +1997,7 @@ async def mock_call(*args, **kwargs): # should keep successful queries assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_negative_batch_timeout(self): """ try to run with batch that starts after operation timeout @@ -1921,10 +2008,10 @@ async def test_read_rows_sharded_negative_batch_timeout(self): from google.api_core.exceptions import DeadlineExceeded async def mock_call(*args, **kwargs): - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) return [mock.Mock()] - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -1939,14 +2026,20 @@ async def mock_call(*args, **kwargs): ) -class TestSampleRowKeys: +@CrossSync.convert_class("TestSampleRowKeys") +class TestSampleRowKeysAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @CrossSync.convert async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): from google.cloud.bigtable_v2.types import SampleRowKeysResponse for value in sample_list: yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys(self): """ Test that method returns the expected key samples @@ -1956,10 +2049,10 @@ async def test_sample_row_keys(self): (b"test_2", 100), (b"test_3", 200), ] - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.return_value = 
self._make_gapic_stream(samples) result = await table.sample_row_keys() @@ -1971,12 +2064,12 @@ async def test_sample_row_keys(self): assert result[1] == samples[1] assert result[2] == samples[2] - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_bad_timeout(self): """ should raise error if timeout is negative """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.sample_row_keys(operation_timeout=-1) @@ -1985,11 +2078,11 @@ async def test_sample_row_keys_bad_timeout(self): await table.sample_row_keys(attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_default_timeout(self): """Should fallback to using table default operation_timeout""" expected_timeout = 99 - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( "i", "t", @@ -1997,7 +2090,7 @@ async def test_sample_row_keys_default_timeout(self): default_attempt_timeout=expected_timeout, ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) result = await table.sample_row_keys() @@ -2006,7 +2099,7 @@ async def test_sample_row_keys_default_timeout(self): assert result == [] assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_gapic_params(self): """ make sure arguments are propagated to gapic call as expected @@ -2015,12 +2108,12 @@ async def test_sample_row_keys_gapic_params(self): expected_profile = "test1" instance = "instance_name" table_id = "my_table" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( instance, table_id, app_profile_id=expected_profile ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) await table.sample_row_keys(attempt_timeout=expected_timeout) @@ -2039,7 +2132,7 @@ async def test_sample_row_keys_gapic_params(self): core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_retryable_errors(self, retryable_exception): """ retryable errors should be retried until timeout @@ -2047,10 +2140,10 @@ async def test_sample_row_keys_retryable_errors(self, retryable_exception): from google.api_core.exceptions import DeadlineExceeded from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.side_effect = retryable_exception("mock") with pytest.raises(DeadlineExceeded) as e: @@ -2071,23 +2164,28 @@ async def test_sample_row_keys_retryable_errors(self, retryable_exception): core_exceptions.Aborted, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): """ 
non-retryable errors should cause a raise """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.side_effect = non_retryable_exception("mock") with pytest.raises(non_retryable_exception): await table.sample_row_keys() -class TestMutateRow: - @pytest.mark.asyncio +@CrossSync.convert_class("TestMutateRow") +class TestMutateRowAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @CrossSync.pytest @pytest.mark.parametrize( "mutation_arg", [ @@ -2108,7 +2206,7 @@ class TestMutateRow: async def test_mutate_row(self, mutation_arg): """Test mutations with no errors""" expected_attempt_timeout = 19 - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2143,12 +2241,12 @@ async def test_mutate_row(self, mutation_arg): core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_retryable_errors(self, retryable_exception): from google.api_core.exceptions import DeadlineExceeded from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2171,14 +2269,14 @@ async def test_mutate_row_retryable_errors(self, retryable_exception): core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_non_idempotent_retryable_errors( self, retryable_exception ): """ Non-idempotent mutations should not be retried """ - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2204,9 +2302,9 @@ async def test_mutate_row_non_idempotent_retryable_errors( core_exceptions.Aborted, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2225,16 +2323,22 @@ async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): ) @pytest.mark.parametrize("mutations", [[], None]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_no_mutations(self, mutations): - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.mutate_row("key", mutations=mutations) assert e.value.args[0] == "No mutations provided" -class TestBulkMutateRows: +@CrossSync.convert_class("TestBulkMutateRows") +class TestBulkMutateRowsAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return 
CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @CrossSync.convert async def _mock_response(self, response_list): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -2254,13 +2358,14 @@ async def _mock_response(self, response_list): for i in range(len(response_list)) ] + @CrossSync.convert async def generator(): yield MutateRowsResponse(entries=entries) return generator() - @pytest.mark.asyncio - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.pytest @pytest.mark.parametrize( "mutation_arg", [ @@ -2283,7 +2388,7 @@ async def generator(): async def test_bulk_mutate_rows(self, mutation_arg): """Test mutations with no errors""" expected_attempt_timeout = 19 - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2304,10 +2409,10 @@ async def test_bulk_mutate_rows(self, mutation_arg): assert kwargs["timeout"] == expected_attempt_timeout assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_rows_multiple_entries(self): """Test mutations with no errors""" - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2328,7 +2433,7 @@ async def test_bulk_mutate_rows_multiple_entries(self): assert kwargs["entries"][0] == entry_1._to_pb() assert kwargs["entries"][1] == entry_2._to_pb() - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "exception", [ @@ -2348,7 +2453,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_retryable( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2373,7 +2478,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_retryable( cause.exceptions[-1], core_exceptions.DeadlineExceeded ) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "exception", [ @@ -2394,7 +2499,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2421,7 +2526,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable( core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_idempotent_retryable_request_errors( self, retryable_exception ): @@ -2434,7 +2539,7 @@ async def test_bulk_mutate_idempotent_retryable_request_errors( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2455,7 +2560,7 @@ async def test_bulk_mutate_idempotent_retryable_request_errors( assert isinstance(cause, RetryExceptionGroup) assert isinstance(cause.exceptions[0], retryable_exception) - @pytest.mark.asyncio + 
@CrossSync.pytest @pytest.mark.parametrize( "retryable_exception", [ @@ -2472,7 +2577,7 @@ async def test_bulk_mutate_rows_non_idempotent_retryable_errors( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2504,7 +2609,7 @@ async def test_bulk_mutate_rows_non_idempotent_retryable_errors( ValueError, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): """ If the request fails with a non-retryable error, mutations should not be retried @@ -2514,7 +2619,7 @@ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_excepti MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2534,7 +2639,7 @@ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_excepti cause = failed_exception.__cause__ assert isinstance(cause, non_retryable_exception) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_error_index(self): """ Test partial failure, partial success. Errors should be associated with the correct index @@ -2550,7 +2655,7 @@ async def test_bulk_mutate_error_index(self): MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2585,14 +2690,14 @@ async def test_bulk_mutate_error_index(self): assert isinstance(cause.exceptions[1], DeadlineExceeded) assert isinstance(cause.exceptions[2], FailedPrecondition) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_error_recovery(self): """ If an error occurs, then resolves, no exception should be raised """ from google.api_core.exceptions import DeadlineExceeded - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: table = client.get_table("instance", "table") with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: # fail with a retryable error, then a non-retryable one @@ -2610,14 +2715,19 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) -class TestCheckAndMutateRow: +@CrossSync.convert_class("TestCheckAndMutateRow") +class TestCheckAndMutateRowAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + @pytest.mark.parametrize("gapic_result", [True, False]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate(self, gapic_result): from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse app_profile = "app_profile_id" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( "instance", "table", app_profile_id=app_profile ) as table: @@ -2654,10 +2764,10 @@ async def test_check_and_mutate(self, gapic_result): assert kwargs["timeout"] == operation_timeout assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def 
test_check_and_mutate_bad_timeout(self): """Should raise error if operation_timeout < 0""" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.check_and_mutate_row( @@ -2669,13 +2779,13 @@ async def test_check_and_mutate_bad_timeout(self): ) assert str(e.value) == "operation_timeout must be greater than 0" - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_single_mutations(self): """if single mutations are passed, they should be internally wrapped in a list""" from google.cloud.bigtable.data.mutations import SetCell from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -2695,7 +2805,7 @@ async def test_check_and_mutate_single_mutations(self): assert kwargs["true_mutations"] == [true_mutation._to_pb()] assert kwargs["false_mutations"] == [false_mutation._to_pb()] - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_predicate_object(self): """predicate filter should be passed to gapic request""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse @@ -2703,7 +2813,7 @@ async def test_check_and_mutate_predicate_object(self): mock_predicate = mock.Mock() predicate_pb = {"predicate": "dict"} mock_predicate._to_pb.return_value = predicate_pb - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -2721,7 +2831,7 @@ async def test_check_and_mutate_predicate_object(self): assert mock_predicate._to_pb.call_count == 1 assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_mutations_parsing(self): """mutations objects should be converted to protos""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse @@ -2731,7 +2841,7 @@ async def test_check_and_mutate_mutations_parsing(self): for idx, mutation in enumerate(mutations): mutation._to_pb.return_value = f"fake {idx}" mutations.append(DeleteAllFromRow()) - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -2758,7 +2868,12 @@ async def test_check_and_mutate_mutations_parsing(self): ) -class TestReadModifyWriteRow: +@CrossSync.convert_class("TestReadModifyWriteRow") +class TestReadModifyWriteRowAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + @pytest.mark.parametrize( "call_rules,expected_rules", [ @@ -2780,12 +2895,12 @@ class TestReadModifyWriteRow: ), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): """ Test that the gapic call is called with given rules """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -2797,21 +2912,21 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules assert 
found_kwargs["retry"] is None @pytest.mark.parametrize("rules", [[], None]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_no_rules(self, rules): - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.read_modify_write_row("key", rules=rules) assert e.value.args[0] == "rules must contain at least one item" - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_call_defaults(self): instance = "instance1" table_id = "table1" project = "project1" row_key = "row_key1" - async with _make_client(project=project) as client: + async with self._make_client(project=project) as client: async with client.get_table(instance, table_id) as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -2827,12 +2942,12 @@ async def test_read_modify_write_call_defaults(self): assert kwargs["row_key"] == row_key.encode() assert kwargs["timeout"] > 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_call_overrides(self): row_key = b"row_key1" expected_timeout = 12345 profile_id = "profile1" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( "instance", "table_id", app_profile_id=profile_id ) as table: @@ -2850,10 +2965,10 @@ async def test_read_modify_write_call_overrides(self): assert kwargs["row_key"] == row_key assert kwargs["timeout"] == expected_timeout - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_string_key(self): row_key = "string_row_key1" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table_id") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -2863,7 +2978,7 @@ async def test_read_modify_write_string_key(self): kwargs = mock_gapic.call_args_list[0][1] assert kwargs["row_key"] == row_key.encode() - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_row_building(self): """ results from gapic call should be used to construct row @@ -2873,7 +2988,7 @@ async def test_read_modify_write_row_building(self): from google.cloud.bigtable_v2.types import Row as RowPB mock_response = ReadModifyWriteRowResponse(row=RowPB()) - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table_id") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -2883,3 +2998,363 @@ async def test_read_modify_write_row_building(self): await table.read_modify_write_row("key", mock.Mock()) assert constructor_mock.call_count == 1 constructor_mock.assert_called_once_with(mock_response.row) + + +@CrossSync.convert_class("TestExecuteQuery") +class TestExecuteQueryAsync: + TABLE_NAME = "TABLE_NAME" + INSTANCE_NAME = "INSTANCE_NAME" + + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @CrossSync.convert + def _make_gapic_stream(self, sample_list: list["ExecuteQueryResponse" | Exception]): + class MockStream: + def __init__(self, sample_list): + self.sample_list = sample_list + + def __aiter__(self): + return self + + def __iter__(self): + return self + + def __next__(self): + if not self.sample_list: + raise CrossSync.StopIteration + value = self.sample_list.pop(0) + if isinstance(value, Exception): + raise 
value + return value + + async def __anext__(self): + return self.__next__() + + return MockStream(sample_list) + + def resonse_with_metadata(self): + from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse + + schema = {"a": "string_type", "b": "int64_type"} + return ExecuteQueryResponse( + { + "metadata": { + "proto_schema": { + "columns": [ + {"name": name, "type_": {_type: {}}} + for name, _type in schema.items() + ] + } + } + } + ) + + def resonse_with_result(self, *args, resume_token=None): + from google.cloud.bigtable_v2.types.data import ProtoRows, Value as PBValue + from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse + + if resume_token is None: + resume_token_dict = {} + else: + resume_token_dict = {"resume_token": resume_token} + + values = [] + for column_value in args: + if column_value is None: + pb_value = PBValue({}) + else: + pb_value = PBValue( + { + "int_value" + if isinstance(column_value, int) + else "string_value": column_value + } + ) + values.append(pb_value) + rows = ProtoRows(values=values) + + return ExecuteQueryResponse( + { + "results": { + "proto_rows_batch": { + "batch_data": ProtoRows.serialize(rows), + }, + **resume_token_dict, + } + } + ) + + @CrossSync.pytest + async def test_execute_query(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert results[1]["a"] == "test2" + assert results[1]["b"] == 9 + assert results[2]["a"] == "test3" + assert results[2]["b"] is None + assert execute_query_mock.call_count == 1 + + @CrossSync.pytest + async def test_execute_query_with_params(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME} WHERE b=@b", + self.INSTANCE_NAME, + parameters={"b": 9}, + ) + results = [r async for r in result] + assert len(results) == 1 + assert results[0]["a"] == "test2" + assert results[0]["b"] == 9 + assert execute_query_mock.call_count == 1 + + @CrossSync.pytest + async def test_execute_query_error_before_metadata(self): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + DeadlineExceeded(""), + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as 
execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert len(results) == 3 + assert execute_query_mock.call_count == 2 + + @CrossSync.pytest + async def test_execute_query_error_after_metadata(self): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + self.resonse_with_metadata(), + DeadlineExceeded(""), + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert len(results) == 3 + assert execute_query_mock.call_count == 2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [] + + @CrossSync.pytest + async def test_execute_query_with_retries(self): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + DeadlineExceeded(""), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + DeadlineExceeded(""), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert results[1]["a"] == "test2" + assert results[1]["b"] == 9 + assert results[2]["a"] == "test3" + assert results[2]["b"] is None + assert len(results) == 3 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"r1", b"r2"] + + @pytest.mark.parametrize( + "exception", + [ + (core_exceptions.DeadlineExceeded("")), + (core_exceptions.Aborted("")), + (core_exceptions.ServiceUnavailable("")), + ], + ) + @CrossSync.pytest + async def test_execute_query_retryable_error(self, exception): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test", resume_token=b"t1"), + exception, + self.resonse_with_result(8, resume_token=b"t2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert len(results) == 1 + assert execute_query_mock.call_count == 
2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"t1"] + + @CrossSync.pytest + async def test_execute_query_retry_partial_row(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test", resume_token=b"t1"), + core_exceptions.DeadlineExceeded(""), + self.resonse_with_result(8, resume_token=b"t2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert execute_query_mock.call_count == 2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"t1"] + + @pytest.mark.parametrize( + "ExceptionType", + [ + (core_exceptions.InvalidArgument), + (core_exceptions.FailedPrecondition), + (core_exceptions.PermissionDenied), + (core_exceptions.MethodNotImplemented), + (core_exceptions.Cancelled), + (core_exceptions.AlreadyExists), + (core_exceptions.OutOfRange), + (core_exceptions.DataLoss), + (core_exceptions.Unauthenticated), + (core_exceptions.NotFound), + (core_exceptions.ResourceExhausted), + (core_exceptions.Unknown), + (core_exceptions.InternalServerError), + ], + ) + @CrossSync.pytest + async def test_execute_query_non_retryable(self, ExceptionType): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + ExceptionType(""), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + r = await CrossSync.next(result) + assert r["a"] == "test" + assert r["b"] == 8 + + with pytest.raises(ExceptionType): + r = await CrossSync.next(result) + + assert execute_query_mock.call_count == 1 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [] + + @CrossSync.pytest + async def test_execute_query_metadata_received_multiple_times_detected(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_metadata(), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + with pytest.raises( + Exception, match="Invalid ExecuteQuery response received" + ): + [ + r + async for r in await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + ] diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index cca7c9824..cd442d392 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ 
b/tests/unit/data/_async/test_mutations_batcher.py @@ -13,34 +13,35 @@ # limitations under the License. import pytest +import mock import asyncio +import time import google.api_core.exceptions as core_exceptions +import google.api_core.retry from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data import TABLE_DEFAULT -# try/except added for compatibility with python < 3.8 -try: - from unittest import mock - from unittest.mock import AsyncMock -except ImportError: # pragma: NO COVER - import mock # type: ignore - from mock import AsyncMock # type: ignore +from google.cloud.bigtable.data._cross_sync import CrossSync +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_mutations_batcher" -def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation +@CrossSync.convert_class(sync_name="Test_FlowControl") +class Test_FlowControlAsync: + @staticmethod + @CrossSync.convert + def _target_class(): + return CrossSync._FlowControl -class Test_FlowControl: def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): - from google.cloud.bigtable.data._async.mutations_batcher import ( - _FlowControlAsync, - ) + return self._target_class()(max_mutation_count, max_mutation_bytes) - return _FlowControlAsync(max_mutation_count, max_mutation_bytes) + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation def test_ctor(self): max_mutation_count = 9 @@ -50,7 +51,7 @@ def test_ctor(self): assert instance._max_mutation_bytes == max_mutation_bytes assert instance._in_flight_mutation_count == 0 assert instance._in_flight_mutation_bytes == 0 - assert isinstance(instance._capacity_condition, asyncio.Condition) + assert isinstance(instance._capacity_condition, CrossSync.Condition) def test_ctor_invalid_values(self): """Test that values are positive, and fit within expected limits""" @@ -110,7 +111,7 @@ def test__has_capacity( instance._in_flight_mutation_bytes = existing_size assert instance._has_capacity(new_count, new_size) == expected - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "existing_count,existing_size,added_count,added_size,new_count,new_size", [ @@ -138,12 +139,12 @@ async def test_remove_from_flow_value_update( instance = self._make_one() instance._in_flight_mutation_count = existing_count instance._in_flight_mutation_bytes = existing_size - mutation = _make_mutation(added_count, added_size) + mutation = self._make_mutation(added_count, added_size) await instance.remove_from_flow(mutation) assert instance._in_flight_mutation_count == new_count assert instance._in_flight_mutation_bytes == new_size - @pytest.mark.asyncio + @CrossSync.pytest async def test__remove_from_flow_unlock(self): """capacity condition should notify after mutation is complete""" instance = self._make_one(10, 10) @@ -156,36 +157,50 @@ async def task_routine(): lambda: instance._has_capacity(1, 1) ) - task = asyncio.create_task(task_routine()) - await asyncio.sleep(0.05) + if CrossSync.is_async: + # for async class, build task to test flow unlock + task = asyncio.create_task(task_routine()) + + def task_alive(): + return not task.done() + + else: + # this branch will be tested in sync version of this test + import threading + + thread = threading.Thread(target=task_routine) + thread.start() + task_alive = thread.is_alive + await 
CrossSync.sleep(0.05) # should be blocked due to capacity - assert task.done() is False + assert task_alive() is True # try changing size - mutation = _make_mutation(count=0, size=5) + mutation = self._make_mutation(count=0, size=5) + await instance.remove_from_flow([mutation]) - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) assert instance._in_flight_mutation_count == 10 assert instance._in_flight_mutation_bytes == 5 - assert task.done() is False + assert task_alive() is True # try changing count instance._in_flight_mutation_bytes = 10 - mutation = _make_mutation(count=5, size=0) + mutation = self._make_mutation(count=5, size=0) await instance.remove_from_flow([mutation]) - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 10 - assert task.done() is False + assert task_alive() is True # try changing both instance._in_flight_mutation_count = 10 - mutation = _make_mutation(count=5, size=5) + mutation = self._make_mutation(count=5, size=5) await instance.remove_from_flow([mutation]) - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 5 # task should be complete - assert task.done() is True + assert task_alive() is False - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "mutations,count_cap,size_cap,expected_results", [ @@ -210,7 +225,7 @@ async def test_add_to_flow(self, mutations, count_cap, size_cap, expected_result """ Test batching with various flow control settings """ - mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations] + mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] instance = self._make_one(count_cap, size_cap) i = 0 async for batch in instance.add_to_flow(mutation_objs): @@ -226,7 +241,7 @@ async def test_add_to_flow(self, mutations, count_cap, size_cap, expected_result i += 1 assert i == len(expected_results) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "mutations,max_limit,expected_results", [ @@ -242,11 +257,12 @@ async def test_add_to_flow_max_mutation_limits( Test flow control running up against the max API limit Should submit request early, even if the flow control has room for more """ - with mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ): - mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations] + subpath = "_async" if CrossSync.is_async else "_sync_autogen" + path = f"google.cloud.bigtable.data.{subpath}.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT" + with mock.patch(path, max_limit): + mutation_objs = [ + self._make_mutation(count=m[0], size=m[1]) for m in mutations + ] # flow control has no limits except API restrictions instance = self._make_one(float("inf"), float("inf")) i = 0 @@ -263,14 +279,14 @@ async def test_add_to_flow_max_mutation_limits( i += 1 assert i == len(expected_results) - @pytest.mark.asyncio + @CrossSync.pytest async def test_add_to_flow_oversize(self): """ mutations over the flow control limits should still be accepted """ instance = self._make_one(2, 3) - large_size_mutation = _make_mutation(count=1, size=10) - large_count_mutation = _make_mutation(count=10, size=1) + large_size_mutation = self._make_mutation(count=1, size=10) + large_count_mutation = self._make_mutation(count=10, size=1) results = [out async for out in 
instance.add_to_flow([large_size_mutation])] assert len(results) == 1 await instance.remove_from_flow(results[0]) @@ -280,13 +296,11 @@ async def test_add_to_flow_oversize(self): assert len(count_results) == 1 +@CrossSync.convert_class(sync_name="TestMutationsBatcher") class TestMutationsBatcherAsync: + @CrossSync.convert def _get_target_class(self): - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) - - return MutationsBatcherAsync + return CrossSync.MutationsBatcher def _make_one(self, table=None, **kwargs): from google.api_core.exceptions import DeadlineExceeded @@ -303,132 +317,140 @@ def _make_one(self, table=None, **kwargs): return self._get_target_class()(table, **kwargs) - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer" - ) - @pytest.mark.asyncio - async def test_ctor_defaults(self, flush_timer_mock): - flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = [Exception] - async with self._make_one(table) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._max_mutation_count == 100000 - assert instance._flow_control._max_mutation_bytes == 104857600 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert ( - instance._operation_timeout - == table.default_mutate_rows_operation_timeout - ) - assert ( - instance._attempt_timeout == table.default_mutate_rows_attempt_timeout - ) - assert ( - instance._retryable_errors == table.default_mutate_rows_retryable_errors - ) - await asyncio.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == 5 - assert isinstance(instance._flush_timer, asyncio.Future) + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer", - ) - @pytest.mark.asyncio - async def test_ctor_explicit(self, flush_timer_mock): + @CrossSync.pytest + async def test_ctor_defaults(self): + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = [Exception] + async with self._make_one(table) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._max_mutation_count == 100000 + assert 
instance._flow_control._max_mutation_bytes == 104857600 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert ( + instance._operation_timeout + == table.default_mutate_rows_operation_timeout + ) + assert ( + instance._attempt_timeout + == table.default_mutate_rows_attempt_timeout + ) + assert ( + instance._retryable_errors + == table.default_mutate_rows_retryable_errors + ) + await CrossSync.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == 5 + assert isinstance(instance._flush_timer, CrossSync.Future) + + @CrossSync.pytest + async def test_ctor_explicit(self): """Test with explicit parameters""" - flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) - table = mock.Mock() - flush_interval = 20 - flush_limit_count = 17 - flush_limit_bytes = 19 - flow_control_max_mutation_count = 1001 - flow_control_max_bytes = 12 - operation_timeout = 11 - attempt_timeout = 2 - retryable_errors = [Exception] - async with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - flow_control_max_mutation_count=flow_control_max_mutation_count, - flow_control_max_bytes=flow_control_max_bytes, - batch_operation_timeout=operation_timeout, - batch_attempt_timeout=attempt_timeout, - batch_retryable_errors=retryable_errors, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert ( - instance._flow_control._max_mutation_count - == flow_control_max_mutation_count - ) - assert instance._flow_control._max_mutation_bytes == flow_control_max_bytes - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert instance._operation_timeout == operation_timeout - assert instance._attempt_timeout == attempt_timeout - assert instance._retryable_errors == retryable_errors - await asyncio.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == flush_interval - assert isinstance(instance._flush_timer, asyncio.Future) - - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer" - ) - @pytest.mark.asyncio - async def test_ctor_no_flush_limits(self, flush_timer_mock): + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() + ) as flush_timer_mock: + table = mock.Mock() + flush_interval = 20 + flush_limit_count = 17 + flush_limit_bytes = 19 + flow_control_max_mutation_count = 1001 + flow_control_max_bytes = 12 + operation_timeout = 11 + attempt_timeout = 2 + retryable_errors = [Exception] + async with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=operation_timeout, + batch_attempt_timeout=attempt_timeout, + 
batch_retryable_errors=retryable_errors, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert ( + instance._flow_control._max_mutation_count + == flow_control_max_mutation_count + ) + assert ( + instance._flow_control._max_mutation_bytes == flow_control_max_bytes + ) + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert instance._operation_timeout == operation_timeout + assert instance._attempt_timeout == attempt_timeout + assert instance._retryable_errors == retryable_errors + await CrossSync.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == flush_interval + assert isinstance(instance._flush_timer, CrossSync.Future) + + @CrossSync.pytest + async def test_ctor_no_flush_limits(self): """Test with None for flush limits""" - flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = () - flush_interval = None - flush_limit_count = None - flush_limit_bytes = None - async with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._staged_entries == [] - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - await asyncio.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] is None - assert isinstance(instance._flush_timer, asyncio.Future) + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = () + flush_interval = None + flush_limit_count = None + flush_limit_bytes = None + async with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._staged_entries == [] + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + await CrossSync.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert 
flush_timer_mock.call_args[0][0] is None + assert isinstance(instance._flush_timer, CrossSync.Future) - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_invalid_values(self): """Test that timeout values are positive, and fit within expected limits""" with pytest.raises(ValueError) as e: @@ -438,24 +460,21 @@ async def test_ctor_invalid_values(self): self._make_one(batch_attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) + @CrossSync.convert def test_default_argument_consistency(self): """ We supply default arguments in MutationsBatcherAsync.__init__, and in table.mutations_batcher. Make sure any changes to defaults are applied to both places """ - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) import inspect get_batcher_signature = dict( - inspect.signature(TableAsync.mutations_batcher).parameters + inspect.signature(CrossSync.Table.mutations_batcher).parameters ) get_batcher_signature.pop("self") batcher_init_signature = dict( - inspect.signature(MutationsBatcherAsync).parameters + inspect.signature(self._get_target_class()).parameters ) batcher_init_signature.pop("table") # both should have same number of arguments @@ -470,97 +489,96 @@ def test_default_argument_consistency(self): == batcher_init_signature[arg_name].default ) - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) - @pytest.mark.asyncio - async def test__start_flush_timer_w_None(self, flush_mock): - """Empty timer should return immediately""" - async with self._make_one() as instance: - with mock.patch("asyncio.sleep") as sleep_mock: - await instance._start_flush_timer(None) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 + @CrossSync.pytest + @pytest.mark.parametrize("input_val", [None, 0, -1]) + async def test__start_flush_timer_w_empty_input(self, input_val): + """Empty/invalid timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + # mock different method depending on sync vs async + async with self._make_one() as instance: + if CrossSync.is_async: + sleep_obj, sleep_method = asyncio, "wait_for" + else: + sleep_obj, sleep_method = instance._closed, "wait" + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + result = await instance._timer_routine(input_val) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + assert result is None - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) - @pytest.mark.asyncio - async def test__start_flush_timer_call_when_closed(self, flush_mock): + @CrossSync.pytest + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + async def test__start_flush_timer_call_when_closed( + self, + ): """closed batcher's timer should return immediately""" - async with self._make_one() as instance: - await instance.close() - flush_mock.reset_mock() - with mock.patch("asyncio.sleep") as sleep_mock: - await instance._start_flush_timer(1) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + async with self._make_one() as instance: + await instance.close() + flush_mock.reset_mock() + # mock different method depending on sync vs async + if CrossSync.is_async: + sleep_obj, sleep_method = asyncio, 
"wait_for" + else: + sleep_obj, sleep_method = instance._closed, "wait" + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + await instance._timer_routine(10) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) - @pytest.mark.asyncio - async def test__flush_timer(self, flush_mock): + @CrossSync.pytest + @pytest.mark.parametrize("num_staged", [0, 1, 10]) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + async def test__flush_timer(self, num_staged): """Timer should continue to call _schedule_flush in a loop""" - expected_sleep = 12 - async with self._make_one(flush_interval=expected_sleep) as instance: - instance._staged_entries = [mock.Mock()] - loop_num = 3 - with mock.patch("asyncio.sleep") as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()] - try: - await instance._flush_timer - except asyncio.CancelledError: - pass - assert sleep_mock.call_count == loop_num + 1 - sleep_mock.assert_called_with(expected_sleep) - assert flush_mock.call_count == loop_num - - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) - @pytest.mark.asyncio - async def test__flush_timer_no_mutations(self, flush_mock): - """Timer should not flush if no new mutations have been staged""" - expected_sleep = 12 - async with self._make_one(flush_interval=expected_sleep) as instance: - loop_num = 3 - with mock.patch("asyncio.sleep") as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()] - try: - await instance._flush_timer - except asyncio.CancelledError: - pass - assert sleep_mock.call_count == loop_num + 1 - sleep_mock.assert_called_with(expected_sleep) - assert flush_mock.call_count == 0 + from google.cloud.bigtable.data._cross_sync import CrossSync - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) - @pytest.mark.asyncio - async def test__flush_timer_close(self, flush_mock): + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + expected_sleep = 12 + async with self._make_one(flush_interval=expected_sleep) as instance: + loop_num = 3 + instance._staged_entries = [mock.Mock()] * num_staged + with mock.patch.object(CrossSync, "event_wait") as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] + with pytest.raises(TabError): + await self._get_target_class()._timer_routine( + instance, expected_sleep + ) + if CrossSync.is_async: + # replace with np-op so there are no issues on close + instance._flush_timer = CrossSync.Future() + assert sleep_mock.call_count == loop_num + 1 + sleep_kwargs = sleep_mock.call_args[1] + assert sleep_kwargs["timeout"] == expected_sleep + assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) + + @CrossSync.pytest + async def test__flush_timer_close(self): """Timer should continue terminate after close""" - async with self._make_one() as instance: - with mock.patch("asyncio.sleep"): + with mock.patch.object(self._get_target_class(), "_schedule_flush"): + async with self._make_one() as instance: # let task run in background - await asyncio.sleep(0.5) assert instance._flush_timer.done() is False # close the batcher await instance.close() - await asyncio.sleep(0.1) # task should be complete assert instance._flush_timer.done() is True - @pytest.mark.asyncio + @CrossSync.pytest 
async def test_append_closed(self): """Should raise exception""" + instance = self._make_one() + await instance.close() with pytest.raises(RuntimeError): - instance = self._make_one() - await instance.close() await instance.append(mock.Mock()) - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_wrong_mutation(self): """ Mutation objects should raise an exception. @@ -574,13 +592,13 @@ async def test_append_wrong_mutation(self): await instance.append(DeleteAllFromRow()) assert str(e.value) == expected_error - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_outside_flow_limits(self): """entries larger than mutation limits are still processed""" async with self._make_one( flow_control_max_mutation_count=1, flow_control_max_bytes=1 ) as instance: - oversized_entry = _make_mutation(count=0, size=2) + oversized_entry = self._make_mutation(count=0, size=2) await instance.append(oversized_entry) assert instance._staged_entries == [oversized_entry] assert instance._staged_count == 0 @@ -589,25 +607,21 @@ async def test_append_outside_flow_limits(self): async with self._make_one( flow_control_max_mutation_count=1, flow_control_max_bytes=1 ) as instance: - overcount_entry = _make_mutation(count=2, size=0) + overcount_entry = self._make_mutation(count=2, size=0) await instance.append(overcount_entry) assert instance._staged_entries == [overcount_entry] assert instance._staged_count == 2 assert instance._staged_bytes == 0 instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_flush_runs_after_limit_hit(self): """ If the user appends a bunch of entries above the flush limits back-to-back, it should still flush in a single task """ - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) - with mock.patch.object( - MutationsBatcherAsync, "_execute_mutate_rows" + self._get_target_class(), "_execute_mutate_rows" ) as op_mock: async with self._make_one(flush_limit_bytes=100) as instance: # mock network calls @@ -616,13 +630,13 @@ async def mock_call(*args, **kwargs): op_mock.side_effect = mock_call # append a mutation just under the size limit - await instance.append(_make_mutation(size=99)) + await instance.append(self._make_mutation(size=99)) # append a bunch of entries back-to-back in a loop num_entries = 10 for _ in range(num_entries): - await instance.append(_make_mutation(size=1)) + await instance.append(self._make_mutation(size=1)) # let any flush jobs finish - await asyncio.gather(*instance._flush_jobs) + await instance._wait_for_batch_results(*instance._flush_jobs) # should have only flushed once, with large mutation and first mutation in loop assert op_mock.call_count == 1 sent_batch = op_mock.call_args[0][0] @@ -642,7 +656,8 @@ async def mock_call(*args, **kwargs): (1, 1, 0, 0, False), ], ) - @pytest.mark.asyncio + @CrossSync.pytest + @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_append( self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush ): @@ -653,7 +668,7 @@ async def test_append( assert instance._staged_count == 0 assert instance._staged_bytes == 0 assert instance._staged_entries == [] - mutation = _make_mutation(count=mutation_count, size=mutation_bytes) + mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) with mock.patch.object(instance, "_schedule_flush") as flush_mock: await instance.append(mutation) assert flush_mock.call_count == bool(expect_flush) @@ -662,7 +677,7 @@ async def test_append( assert 
instance._staged_entries == [mutation] instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_multiple_sequentially(self): """Append multiple mutations""" async with self._make_one( @@ -671,7 +686,7 @@ async def test_append_multiple_sequentially(self): assert instance._staged_count == 0 assert instance._staged_bytes == 0 assert instance._staged_entries == [] - mutation = _make_mutation(count=2, size=3) + mutation = self._make_mutation(count=2, size=3) with mock.patch.object(instance, "_schedule_flush") as flush_mock: await instance.append(mutation) assert flush_mock.call_count == 0 @@ -690,7 +705,7 @@ async def test_append_multiple_sequentially(self): assert len(instance._staged_entries) == 3 instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_flush_flow_control_concurrent_requests(self): """ requests should happen in parallel if flow control breaks up single flush into batches @@ -698,14 +713,14 @@ async def test_flush_flow_control_concurrent_requests(self): import time num_calls = 10 - fake_mutations = [_make_mutation(count=1) for _ in range(num_calls)] + fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] async with self._make_one(flow_control_max_mutation_count=1) as instance: with mock.patch.object( - instance, "_execute_mutate_rows", AsyncMock() + instance, "_execute_mutate_rows", CrossSync.Mock() ) as op_mock: # mock network calls async def mock_call(*args, **kwargs): - await asyncio.sleep(0.1) + await CrossSync.sleep(0.1) return [] op_mock.side_effect = mock_call @@ -713,15 +728,15 @@ async def mock_call(*args, **kwargs): # flush one large batch, that will be broken up into smaller batches instance._staged_entries = fake_mutations instance._schedule_flush() - await asyncio.sleep(0.01) + await CrossSync.sleep(0.01) # make room for new mutations for i in range(num_calls): await instance._flow_control.remove_from_flow( - [_make_mutation(count=1)] + [self._make_mutation(count=1)] ) - await asyncio.sleep(0.01) + await CrossSync.sleep(0.01) # allow flushes to complete - await asyncio.gather(*instance._flush_jobs) + await instance._wait_for_batch_results(*instance._flush_jobs) duration = time.monotonic() - start_time assert len(instance._oldest_exceptions) == 0 assert len(instance._newest_exceptions) == 0 @@ -729,7 +744,7 @@ async def mock_call(*args, **kwargs): assert duration < 0.5 assert op_mock.call_count == num_calls - @pytest.mark.asyncio + @CrossSync.pytest async def test_schedule_flush_no_mutations(self): """schedule flush should return None if no staged mutations""" async with self._make_one() as instance: @@ -738,11 +753,15 @@ async def test_schedule_flush_no_mutations(self): assert instance._schedule_flush() is None assert flush_mock.call_count == 0 - @pytest.mark.asyncio + @CrossSync.pytest + @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_schedule_flush_with_mutations(self): """if new mutations exist, should add a new flush task to _flush_jobs""" async with self._make_one() as instance: with mock.patch.object(instance, "_flush_internal") as flush_mock: + if not CrossSync.is_async: + # simulate operation + flush_mock.side_effect = lambda x: time.sleep(0.1) for i in range(1, 4): mutation = mock.Mock() instance._staged_entries = [mutation] @@ -753,9 +772,10 @@ async def test_schedule_flush_with_mutations(self): assert instance._staged_entries == [] assert instance._staged_count == 0 assert instance._staged_bytes == 0 - assert flush_mock.call_count == i + assert 
flush_mock.call_count == 1 + flush_mock.reset_mock() - @pytest.mark.asyncio + @CrossSync.pytest async def test__flush_internal(self): """ _flush_internal should: @@ -775,7 +795,7 @@ async def gen(x): yield x flow_mock.side_effect = lambda x: gen(x) - mutations = [_make_mutation(count=1, size=1)] * num_entries + mutations = [self._make_mutation(count=1, size=1)] * num_entries await instance._flush_internal(mutations) assert instance._entries_processed_since_last_raise == num_entries assert execute_mock.call_count == 1 @@ -783,20 +803,28 @@ async def gen(x): instance._oldest_exceptions.clear() instance._newest_exceptions.clear() - @pytest.mark.asyncio + @CrossSync.pytest async def test_flush_clears_job_list(self): """ a job should be added to _flush_jobs when _schedule_flush is called, and removed when it completes """ async with self._make_one() as instance: - with mock.patch.object(instance, "_flush_internal", AsyncMock()): - mutations = [_make_mutation(count=1, size=1)] + with mock.patch.object( + instance, "_flush_internal", CrossSync.Mock() + ) as flush_mock: + if not CrossSync.is_async: + # simulate operation + flush_mock.side_effect = lambda x: time.sleep(0.1) + mutations = [self._make_mutation(count=1, size=1)] instance._staged_entries = mutations assert instance._flush_jobs == set() new_job = instance._schedule_flush() assert instance._flush_jobs == {new_job} - await new_job + if CrossSync.is_async: + await new_job + else: + new_job.result() assert instance._flush_jobs == set() @pytest.mark.parametrize( @@ -811,7 +839,7 @@ async def test_flush_clears_job_list(self): (10, 20, 20), # should cap at 20 ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test__flush_internal_with_errors( self, num_starting, num_new_errors, expected_total_errors ): @@ -836,7 +864,7 @@ async def gen(x): yield x flow_mock.side_effect = lambda x: gen(x) - mutations = [_make_mutation(count=1, size=1)] * num_entries + mutations = [self._make_mutation(count=1, size=1)] * num_entries await instance._flush_internal(mutations) assert instance._entries_processed_since_last_raise == num_entries assert execute_mock.call_count == 1 @@ -853,10 +881,12 @@ async def gen(x): instance._oldest_exceptions.clear() instance._newest_exceptions.clear() + @CrossSync.convert async def _mock_gapic_return(self, num=5): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 + @CrossSync.convert async def gen(num): for i in range(num): entry = MutateRowsResponse.Entry( @@ -866,11 +896,11 @@ async def gen(num): return gen(num) - @pytest.mark.asyncio + @CrossSync.pytest async def test_timer_flush_end_to_end(self): """Flush should automatically trigger after flush_interval""" - num_nutations = 10 - mutations = [_make_mutation(count=2, size=2)] * num_nutations + num_mutations = 10 + mutations = [self._make_mutation(count=2, size=2)] * num_mutations async with self._make_one(flush_interval=0.05) as instance: instance._table.default_operation_timeout = 10 @@ -879,69 +909,65 @@ async def test_timer_flush_end_to_end(self): instance._table.client._gapic_client, "mutate_rows" ) as gapic_mock: gapic_mock.side_effect = ( - lambda *args, **kwargs: self._mock_gapic_return(num_nutations) + lambda *args, **kwargs: self._mock_gapic_return(num_mutations) ) for m in mutations: await instance.append(m) assert instance._entries_processed_since_last_raise == 0 # let flush trigger due to timer - await asyncio.sleep(0.1) - assert instance._entries_processed_since_last_raise == num_nutations - - 
@pytest.mark.asyncio
-    @mock.patch(
-        "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync",
-    )
-    async def test__execute_mutate_rows(self, mutate_rows):
-        mutate_rows.return_value = AsyncMock()
-        start_operation = mutate_rows().start
-        table = mock.Mock()
-        table.table_name = "test-table"
-        table.app_profile_id = "test-app-profile"
-        table.default_mutate_rows_operation_timeout = 17
-        table.default_mutate_rows_attempt_timeout = 13
-        table.default_mutate_rows_retryable_errors = ()
-        async with self._make_one(table) as instance:
-            batch = [_make_mutation()]
-            result = await instance._execute_mutate_rows(batch)
-            assert start_operation.call_count == 1
-            args, kwargs = mutate_rows.call_args
-            assert args[0] == table.client._gapic_client
-            assert args[1] == table
-            assert args[2] == batch
-            kwargs["operation_timeout"] == 17
-            kwargs["attempt_timeout"] == 13
-            assert result == []
-
-    @pytest.mark.asyncio
-    @mock.patch(
-        "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync.start"
-    )
-    async def test__execute_mutate_rows_returns_errors(self, mutate_rows):
+                await CrossSync.sleep(0.1)
+                assert instance._entries_processed_since_last_raise == num_mutations
+
+    @CrossSync.pytest
+    async def test__execute_mutate_rows(self):
+        with mock.patch.object(CrossSync, "_MutateRowsOperation") as mutate_rows:
+            mutate_rows.return_value = CrossSync.Mock()
+            start_operation = mutate_rows().start
+            table = mock.Mock()
+            table.table_name = "test-table"
+            table.app_profile_id = "test-app-profile"
+            table.default_mutate_rows_operation_timeout = 17
+            table.default_mutate_rows_attempt_timeout = 13
+            table.default_mutate_rows_retryable_errors = ()
+            async with self._make_one(table) as instance:
+                batch = [self._make_mutation()]
+                result = await instance._execute_mutate_rows(batch)
+                assert start_operation.call_count == 1
+                args, kwargs = mutate_rows.call_args
+                assert args[0] == table.client._gapic_client
+                assert args[1] == table
+                assert args[2] == batch
+                kwargs["operation_timeout"] == 17
+                kwargs["attempt_timeout"] == 13
+                assert result == []
+
+    @CrossSync.pytest
+    async def test__execute_mutate_rows_returns_errors(self):
         """Errors from operation should be returned as list"""
         from google.cloud.bigtable.data.exceptions import (
             MutationsExceptionGroup,
             FailedMutationEntryError,
         )

-        err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error"))
-        err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error"))
-        mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10)
-        table = mock.Mock()
-        table.default_mutate_rows_operation_timeout = 17
-        table.default_mutate_rows_attempt_timeout = 13
-        table.default_mutate_rows_retryable_errors = ()
-        async with self._make_one(table) as instance:
-            batch = [_make_mutation()]
-            result = await instance._execute_mutate_rows(batch)
-            assert len(result) == 2
-            assert result[0] == err1
-            assert result[1] == err2
-            # indices should be set to None
-            assert result[0].index is None
-            assert result[1].index is None
-
-    @pytest.mark.asyncio
+        with mock.patch.object(CrossSync._MutateRowsOperation, "start") as mutate_rows:
+            err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error"))
+            err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error"))
+            mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10)
+            table = mock.Mock()
+            table.default_mutate_rows_operation_timeout = 17
+            table.default_mutate_rows_attempt_timeout = 13
+            
table.default_mutate_rows_retryable_errors = () + async with self._make_one(table) as instance: + batch = [self._make_mutation()] + result = await instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + # indices should be set to None + assert result[0].index is None + assert result[1].index is None + + @CrossSync.pytest async def test__raise_exceptions(self): """Raise exceptions and reset error state""" from google.cloud.bigtable.data import exceptions @@ -961,13 +987,19 @@ async def test__raise_exceptions(self): # try calling again instance._raise_exceptions() - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.convert( + sync_name="test___enter__", replace_symbols={"__aenter__": "__enter__"} + ) async def test___aenter__(self): """Should return self""" async with self._make_one() as instance: assert await instance.__aenter__() == instance - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.convert( + sync_name="test___exit__", replace_symbols={"__aexit__": "__exit__"} + ) async def test___aexit__(self): """aexit should call close""" async with self._make_one() as instance: @@ -975,7 +1007,7 @@ async def test___aexit__(self): await instance.__aexit__(None, None, None) assert close_mock.call_count == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_close(self): """Should clean up all resources""" async with self._make_one() as instance: @@ -988,7 +1020,7 @@ async def test_close(self): assert flush_mock.call_count == 1 assert raise_mock.call_count == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_close_w_exceptions(self): """Raise exceptions on close""" from google.cloud.bigtable.data import exceptions @@ -1007,7 +1039,7 @@ async def test_close_w_exceptions(self): # clear out exceptions instance._oldest_exceptions, instance._newest_exceptions = ([], []) - @pytest.mark.asyncio + @CrossSync.pytest async def test__on_exit(self, recwarn): """Should raise warnings if unflushed mutations exist""" async with self._make_one() as instance: @@ -1023,13 +1055,13 @@ async def test__on_exit(self, recwarn): assert "unflushed mutations" in str(w[0].message).lower() assert str(num_left) in str(w[0].message) # calling while closed is noop - instance.closed = True + instance._closed.set() instance._on_exit() assert len(recwarn) == 0 # reset staged mutations for cleanup instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_atexit_registration(self): """Should run _on_exit on program termination""" import atexit @@ -1039,30 +1071,29 @@ async def test_atexit_registration(self): async with self._make_one(): assert register_mock.call_count == 1 - @pytest.mark.asyncio - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", - ) - async def test_timeout_args_passed(self, mutate_rows): + @CrossSync.pytest + async def test_timeout_args_passed(self): """ batch_operation_timeout and batch_attempt_timeout should be used in api calls """ - mutate_rows.return_value = AsyncMock() - expected_operation_timeout = 17 - expected_attempt_timeout = 13 - async with self._make_one( - batch_operation_timeout=expected_operation_timeout, - batch_attempt_timeout=expected_attempt_timeout, - ) as instance: - assert instance._operation_timeout == expected_operation_timeout - assert instance._attempt_timeout == expected_attempt_timeout - # make simulated gapic call - await instance._execute_mutate_rows([_make_mutation()]) - assert mutate_rows.call_count == 1 - kwargs = 
mutate_rows.call_args[1] - assert kwargs["operation_timeout"] == expected_operation_timeout - assert kwargs["attempt_timeout"] == expected_attempt_timeout + with mock.patch.object( + CrossSync, "_MutateRowsOperation", return_value=CrossSync.Mock() + ) as mutate_rows: + expected_operation_timeout = 17 + expected_attempt_timeout = 13 + async with self._make_one( + batch_operation_timeout=expected_operation_timeout, + batch_attempt_timeout=expected_attempt_timeout, + ) as instance: + assert instance._operation_timeout == expected_operation_timeout + assert instance._attempt_timeout == expected_attempt_timeout + # make simulated gapic call + await instance._execute_mutate_rows([self._make_mutation()]) + assert mutate_rows.call_count == 1 + kwargs = mutate_rows.call_args[1] + assert kwargs["operation_timeout"] == expected_operation_timeout + assert kwargs["attempt_timeout"] == expected_attempt_timeout @pytest.mark.parametrize( "limit,in_e,start_e,end_e", @@ -1123,7 +1154,7 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): for i in range(1, newest_list_diff + 1): assert mock_batcher._newest_exceptions[-i] == input_list[-i] - @pytest.mark.asyncio + @CrossSync.pytest # test different inputs for retryable exceptions @pytest.mark.parametrize( "input_retryables,expected_retryables", @@ -1148,6 +1179,7 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): ([4], [core_exceptions.DeadlineExceeded]), ], ) + @CrossSync.convert async def test_customizable_retryable_errors( self, input_retryables, expected_retryables ): @@ -1155,25 +1187,21 @@ async def test_customizable_retryable_errors( Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. """ - from google.cloud.bigtable.data._async.client import TableAsync - - with mock.patch( - "google.api_core.retry.if_exception_type" + with mock.patch.object( + google.api_core.retry, "if_exception_type" ) as predicate_builder_mock: - with mock.patch( - "google.api_core.retry.retry_target_async" - ) as retry_fn_mock: + with mock.patch.object(CrossSync, "retry_target") as retry_fn_mock: table = None with mock.patch("asyncio.create_task"): - table = TableAsync(mock.Mock(), "instance", "table") + table = CrossSync.Table(mock.Mock(), "instance", "table") async with self._make_one( table, batch_retryable_errors=input_retryables ) as instance: assert instance._retryable_errors == expected_retryables - expected_predicate = lambda a: a in expected_retryables # noqa + expected_predicate = expected_retryables.__contains__ predicate_builder_mock.return_value = expected_predicate retry_fn_mock.side_effect = RuntimeError("stop early") - mutation = _make_mutation(count=1, size=1) + mutation = self._make_mutation(count=1, size=1) await instance._execute_mutate_rows([mutation]) # passed in errors should be used to build the predicate predicate_builder_mock.assert_called_once_with( @@ -1182,3 +1210,25 @@ async def test_customizable_retryable_errors( retry_call_args = retry_fn_mock.call_args_list[0].args # output of if_exception_type should be sent in to retry constructor assert retry_call_args[1] is expected_predicate + + @CrossSync.pytest + async def test_large_batch_write(self): + """ + Test that a large batch of mutations can be written + """ + import math + + num_mutations = 10_000 + flush_limit = 1000 + mutations = [self._make_mutation(count=1, size=1)] * num_mutations + async with self._make_one(flush_limit_mutation_count=flush_limit) as instance: + operation_mock = 
mock.Mock() + rpc_call_mock = CrossSync.Mock() + operation_mock().start = rpc_call_mock + CrossSync._MutateRowsOperation = operation_mock + for m in mutations: + await instance.append(m) + expected_calls = math.ceil(num_mutations / flush_limit) + assert rpc_call_mock.call_count == expected_calls + assert instance._entries_processed_since_last_raise == num_mutations + assert len(instance._staged_entries) == 0 diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py new file mode 100644 index 000000000..45d139182 --- /dev/null +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -0,0 +1,355 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import warnings +import pytest +import mock + +from itertools import zip_longest + +from google.cloud.bigtable_v2 import ReadRowsResponse + +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.row import Row + +from ...v2_client.test_row_merger import ReadRowsTest, TestFile + +from google.cloud.bigtable.data._cross_sync import CrossSync + + +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_read_rows_acceptance" + + +@CrossSync.convert_class( + sync_name="TestReadRowsAcceptance", +) +class TestReadRowsAcceptanceAsync: + @staticmethod + @CrossSync.convert + def _get_operation_class(): + return CrossSync._ReadRowsOperation + + @staticmethod + @CrossSync.convert + def _get_client_class(): + return CrossSync.DataClient + + def parse_readrows_acceptance_tests(): + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "../read-rows-acceptance-test.json") + + with open(filename) as json_file: + test_json = TestFile.from_json(json_file.read()) + return test_json.read_rows_tests + + @staticmethod + def extract_results_from_row(row: Row): + results = [] + for family, col, cells in row.items(): + for cell in cells: + results.append( + ReadRowsTest.Result( + row_key=row.row_key, + family_name=family, + qualifier=col, + timestamp_micros=cell.timestamp_ns // 1000, + value=cell.value, + label=(cell.labels[0] if cell.labels else ""), + ) + ) + return results + + @staticmethod + @CrossSync.convert + async def _coro_wrapper(stream): + return stream + + @CrossSync.convert + async def _process_chunks(self, *chunks): + @CrossSync.convert + async def _row_stream(): + yield ReadRowsResponse(chunks=chunks) + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + results = [] + async for row in merger: + results.append(row) + return results + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + @CrossSync.pytest + async def test_row_merger_scenario(self, test_case: ReadRowsTest): + async def 
_scenerio_stream(): + for chunk in test_case.chunks: + yield ReadRowsResponse(chunks=[chunk]) + + try: + results = [] + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_scenerio_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + async for row in merger: + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + @CrossSync.pytest + async def test_read_rows_scenario(self, test_case: ReadRowsTest): + async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list): + self.chunk_list = chunk_list + self.idx = -1 + + def __aiter__(self): + return self + + def __iter__(self): + return self + + async def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + chunk = self.chunk_list[self.idx] + return ReadRowsResponse(chunks=[chunk]) + raise CrossSync.StopIteration + + def __next__(self): + return self.__anext__() + + def cancel(self): + pass + + return mock_stream(chunk_list) + + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # use emulator mode to avoid auth issues in CI + client = self._get_client_class()() + try: + table = client.get_table("instance", "table") + results = [] + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + # run once, then return error on retry + read_rows.return_value = _make_gapic_stream(test_case.chunks) + async for row in await table.read_rows_stream(query={}): + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + finally: + await client.close() + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + @CrossSync.pytest + async def test_out_of_order_rows(self): + async def _row_stream(): + yield ReadRowsResponse(last_scanned_row_key=b"a") + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = b"b" + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + with pytest.raises(InvalidChunk): + async for _ in merger: + pass + + @CrossSync.pytest + async def test_bare_reset(self): + first_chunk = ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk( + row_key=b"a", family_name="f", qualifier=b"q", value=b"v" + ) + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") + 
), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, family_name="f") + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, value=b"v") + ), + ) + + @CrossSync.pytest + async def test_missing_family(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + qualifier=b"q", + timestamp_micros=1000, + value=b"v", + commit_row=True, + ) + ) + + @CrossSync.pytest + async def test_mid_cell_row_key_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), + ) + + @CrossSync.pytest + async def test_mid_cell_family_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + family_name="f2", value=b"v", commit_row=True + ), + ) + + @CrossSync.pytest + async def test_mid_cell_qualifier_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + qualifier=b"q2", value=b"v", commit_row=True + ), + ) + + @CrossSync.pytest + async def test_mid_cell_timestamp_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + timestamp_micros=2000, value=b"v", commit_row=True + ), + ) + + @CrossSync.pytest + async def test_mid_cell_labels_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), + ) diff --git a/tests/unit/data/execute_query/_async/_testing.py b/tests/unit/data/execute_query/_async/_testing.py deleted file mode 100644 index 5a7acbdd9..000000000 --- a/tests/unit/data/execute_query/_async/_testing.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
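The new test modules in this patch all follow the same cross-sync pattern: tests are written once against the async surface, annotated with CrossSync markers, and a sync twin is generated from them. As a rough illustration (not part of this patch; the class name, module, and output path below are made up), a minimal test written in that style might look like:

# Hypothetical example only, patterned on the annotations used in this patch.
from google.cloud.bigtable.data._cross_sync import CrossSync

# module the generated sync variant would be written to (made-up path)
__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_example"


@CrossSync.convert_class(sync_name="TestExample")
class TestExampleAsync:
    @CrossSync.pytest
    async def test_callback_is_invoked(self):
        # CrossSync.Mock stands in for AsyncMock here, and for a plain Mock in
        # the generated sync version
        callback = CrossSync.Mock()
        await CrossSync.sleep(0)  # asyncio.sleep here; a blocking sleep in sync code
        if CrossSync.is_async:
            await callback()
        else:
            callback()
        callback.assert_called_once()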
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# flake8: noqa -from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes - - -try: - # async mock for python3.7-10 - from unittest.mock import Mock - from asyncio import coroutine - - def async_mock(return_value=None): - coro = Mock(name="CoroutineResult") - corofunc = Mock(name="CoroutineFunction", side_effect=coroutine(coro)) - corofunc.coro = coro - corofunc.coro.return_value = return_value - return corofunc - -except ImportError: - # async mock for python3.11 or later - from unittest.mock import AsyncMock - - def async_mock(return_value=None): - return AsyncMock(return_value=return_value) diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py index 5c577ed74..9bdf17c27 100644 --- a/tests/unit/data/execute_query/_async/test_query_iterator.py +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -13,144 +13,171 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from unittest.mock import Mock -from mock import patch import pytest -from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( - ExecuteQueryIteratorAsync, -) +import concurrent.futures from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse -from ._testing import TYPE_INT, proto_rows_bytes, split_bytes_into_chunks, async_mock +from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes + +from google.cloud.bigtable.data._cross_sync import CrossSync + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock +except ImportError: # pragma: NO COVER + import mock # type: ignore -class MockIteratorAsync: +__CROSS_SYNC_OUTPUT__ = ( + "tests.unit.data.execute_query._sync_autogen.test_query_iterator" +) + + +@CrossSync.convert_class(sync_name="MockIterator") +class MockIterator: def __init__(self, values, delay=None): self._values = values self.idx = 0 self._delay = delay + @CrossSync.convert(sync_name="__iter__") def __aiter__(self): return self + @CrossSync.convert(sync_name="__next__") async def __anext__(self): if self.idx >= len(self._values): - raise StopAsyncIteration + raise CrossSync.StopIteration if self._delay is not None: - await asyncio.sleep(self._delay) + await CrossSync.sleep(self._delay) value = self._values[self.idx] self.idx += 1 return value -@pytest.fixture -def proto_byte_stream(): - proto_rows = [ - proto_rows_bytes({"int_value": 1}, {"int_value": 2}), - proto_rows_bytes({"int_value": 3}, {"int_value": 4}), - proto_rows_bytes({"int_value": 5}, {"int_value": 6}), - ] - - messages = [ - *split_bytes_into_chunks(proto_rows[0], num_chunks=2), - *split_bytes_into_chunks(proto_rows[1], num_chunks=3), - proto_rows[2], - ] - - stream = [ - ExecuteQueryResponse( - metadata={ - "proto_schema": { - "columns": [ - {"name": "test1", "type_": TYPE_INT}, - {"name": "test2", "type_": TYPE_INT}, - ] +@CrossSync.convert_class(sync_name="TestQueryIterator") +class TestQueryIteratorAsync: + @staticmethod + def _target_class(): + return 
CrossSync.ExecuteQueryIterator + + def _make_one(self, *args, **kwargs): + return self._target_class()(*args, **kwargs) + + @pytest.fixture + def proto_byte_stream(self): + proto_rows = [ + proto_rows_bytes({"int_value": 1}, {"int_value": 2}), + proto_rows_bytes({"int_value": 3}, {"int_value": 4}), + proto_rows_bytes({"int_value": 5}, {"int_value": 6}), + ] + + messages = [ + *split_bytes_into_chunks(proto_rows[0], num_chunks=2), + *split_bytes_into_chunks(proto_rows[1], num_chunks=3), + proto_rows[2], + ] + + stream = [ + ExecuteQueryResponse( + metadata={ + "proto_schema": { + "columns": [ + {"name": "test1", "type_": TYPE_INT}, + {"name": "test2", "type_": TYPE_INT}, + ] + } + } + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[0]}} + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[1]}, + "resume_token": b"token1", + } + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[2]}} + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[3]}} + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[4]}, + "resume_token": b"token2", + } + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[5]}, + "resume_token": b"token3", } - } - ), - ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[0]}}), - ExecuteQueryResponse( - results={ - "proto_rows_batch": {"batch_data": messages[1]}, - "resume_token": b"token1", - } - ), - ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[2]}}), - ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[3]}}), - ExecuteQueryResponse( - results={ - "proto_rows_batch": {"batch_data": messages[4]}, - "resume_token": b"token2", - } - ), - ExecuteQueryResponse( - results={ - "proto_rows_batch": {"batch_data": messages[5]}, - "resume_token": b"token3", - } - ), - ] - return stream - - -@pytest.mark.asyncio -async def test_iterator(proto_byte_stream): - client_mock = Mock() - - client_mock._register_instance = async_mock() - client_mock._remove_instance_registration = async_mock() - mock_async_iterator = MockIteratorAsync(proto_byte_stream) - iterator = None - - with patch( - "google.api_core.retry.retry_target_stream_async", - return_value=mock_async_iterator, - ): - iterator = ExecuteQueryIteratorAsync( - client=client_mock, - instance_id="test-instance", - app_profile_id="test_profile", - request_body={}, - attempt_timeout=10, - operation_timeout=10, - req_metadata=(), - retryable_excs=[], - ) - result = [] - async for value in iterator: - result.append(tuple(value)) - assert result == [(1, 2), (3, 4), (5, 6)] - - assert iterator.is_closed - client_mock._register_instance.assert_called_once() - client_mock._remove_instance_registration.assert_called_once() - - assert mock_async_iterator.idx == len(proto_byte_stream) - - -@pytest.mark.asyncio -async def test_iterator_awaits_metadata(proto_byte_stream): - client_mock = Mock() - - client_mock._register_instance = async_mock() - client_mock._remove_instance_registration = async_mock() - mock_async_iterator = MockIteratorAsync(proto_byte_stream) - iterator = None - with patch( - "google.api_core.retry.retry_target_stream_async", - return_value=mock_async_iterator, - ): - iterator = ExecuteQueryIteratorAsync( - client=client_mock, - instance_id="test-instance", - app_profile_id="test_profile", - request_body={}, - attempt_timeout=10, - operation_timeout=10, - 
req_metadata=(), - retryable_excs=[], - ) - - await iterator.metadata() - - assert mock_async_iterator.idx == 1 + ), + ] + return stream + + @CrossSync.pytest + async def test_iterator(self, proto_byte_stream): + client_mock = mock.Mock() + + client_mock._register_instance = CrossSync.Mock() + client_mock._remove_instance_registration = CrossSync.Mock() + client_mock._executor = concurrent.futures.ThreadPoolExecutor() + mock_async_iterator = MockIterator(proto_byte_stream) + iterator = None + + with mock.patch.object( + CrossSync, + "retry_target_stream", + return_value=mock_async_iterator, + ): + iterator = self._make_one( + client=client_mock, + instance_id="test-instance", + app_profile_id="test_profile", + request_body={}, + attempt_timeout=10, + operation_timeout=10, + req_metadata=(), + retryable_excs=[], + ) + result = [] + async for value in iterator: + result.append(tuple(value)) + assert result == [(1, 2), (3, 4), (5, 6)] + + assert iterator.is_closed + client_mock._register_instance.assert_called_once() + client_mock._remove_instance_registration.assert_called_once() + + assert mock_async_iterator.idx == len(proto_byte_stream) + + @CrossSync.pytest + async def test_iterator_awaits_metadata(self, proto_byte_stream): + client_mock = mock.Mock() + + client_mock._register_instance = CrossSync.Mock() + client_mock._remove_instance_registration = CrossSync.Mock() + mock_async_iterator = MockIterator(proto_byte_stream) + iterator = None + with mock.patch.object( + CrossSync, + "retry_target_stream", + return_value=mock_async_iterator, + ): + iterator = self._make_one( + client=client_mock, + instance_id="test-instance", + app_profile_id="test_profile", + request_body={}, + attempt_timeout=10, + operation_timeout=10, + req_metadata=(), + retryable_excs=[], + ) + + await iterator.metadata() + + assert mock_async_iterator.idx == 1 diff --git a/tests/unit/data/test_read_rows_acceptance.py b/tests/unit/data/test_read_rows_acceptance.py deleted file mode 100644 index 7cb3c08dc..000000000 --- a/tests/unit/data/test_read_rows_acceptance.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
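The MockIterator in the query iterator tests above also shows how the conversion annotations map the iteration protocol: __aiter__ and __anext__ carry sync_name="__iter__" and "__next__", while CrossSync.StopIteration and CrossSync.sleep stand in for the async and sync primitives. Roughly, the generated sync twin would reduce to something like the following (illustrative only; the real output comes from the generation step and its exact form may differ):

import time


class MockIterator:
    def __init__(self, values, delay=None):
        self._values = values
        self.idx = 0
        self._delay = delay

    def __iter__(self):
        return self

    def __next__(self):
        if self.idx >= len(self._values):
            raise StopIteration
        if self._delay is not None:
            time.sleep(self._delay)
        value = self._values[self.idx]
        self.idx += 1
        return value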
-from __future__ import annotations - -import os -from itertools import zip_longest - -import pytest -import mock - -from google.cloud.bigtable_v2 import ReadRowsResponse - -from google.cloud.bigtable.data._async.client import BigtableDataClientAsync -from google.cloud.bigtable.data.exceptions import InvalidChunk -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync -from google.cloud.bigtable.data.row import Row - -from ..v2_client.test_row_merger import ReadRowsTest, TestFile - - -def parse_readrows_acceptance_tests(): - dirname = os.path.dirname(__file__) - filename = os.path.join(dirname, "./read-rows-acceptance-test.json") - - with open(filename) as json_file: - test_json = TestFile.from_json(json_file.read()) - return test_json.read_rows_tests - - -def extract_results_from_row(row: Row): - results = [] - for family, col, cells in row.items(): - for cell in cells: - results.append( - ReadRowsTest.Result( - row_key=row.row_key, - family_name=family, - qualifier=col, - timestamp_micros=cell.timestamp_ns // 1000, - value=cell.value, - label=(cell.labels[0] if cell.labels else ""), - ) - ) - return results - - -@pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description -) -@pytest.mark.asyncio -async def test_row_merger_scenario(test_case: ReadRowsTest): - async def _scenerio_stream(): - for chunk in test_case.chunks: - yield ReadRowsResponse(chunks=[chunk]) - - try: - results = [] - instance = mock.Mock() - instance._last_yielded_row_key = None - instance._remaining_count = None - chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_scenerio_stream()) - ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) - async for row in merger: - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - -@pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description -) -@pytest.mark.asyncio -async def test_read_rows_scenario(test_case: ReadRowsTest): - async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): - from google.cloud.bigtable_v2 import ReadRowsResponse - - class mock_stream: - def __init__(self, chunk_list): - self.chunk_list = chunk_list - self.idx = -1 - - def __aiter__(self): - return self - - async def __anext__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - chunk = self.chunk_list[self.idx] - return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration - - def cancel(self): - pass - - return mock_stream(chunk_list) - - try: - with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): - # use emulator mode to avoid auth issues in CI - client = BigtableDataClientAsync() - table = client.get_table("instance", "table") - results = [] - with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: - # run once, then return error on retry - read_rows.return_value = _make_gapic_stream(test_case.chunks) - async for row in await table.read_rows_stream(query={}): - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - 
timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - finally: - await client.close() - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - -@pytest.mark.asyncio -async def test_out_of_order_rows(): - async def _row_stream(): - yield ReadRowsResponse(last_scanned_row_key=b"a") - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = b"b" - chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_row_stream()) - ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) - with pytest.raises(InvalidChunk): - async for _ in merger: - pass - - -@pytest.mark.asyncio -async def test_bare_reset(): - first_chunk = ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk( - row_key=b"a", family_name="f", qualifier=b"q", value=b"v" - ) - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, family_name="f") - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, value=b"v") - ), - ) - - -@pytest.mark.asyncio -async def test_missing_family(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - qualifier=b"q", - timestamp_micros=1000, - value=b"v", - commit_row=True, - ) - ) - - -@pytest.mark.asyncio -async def test_mid_cell_row_key_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), - ) - - -@pytest.mark.asyncio -async def test_mid_cell_family_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(family_name="f2", value=b"v", commit_row=True), - ) - - -@pytest.mark.asyncio -async def test_mid_cell_qualifier_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(qualifier=b"q2", value=b"v", commit_row=True), - ) - - -@pytest.mark.asyncio -async def test_mid_cell_timestamp_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - 
family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - timestamp_micros=2000, value=b"v", commit_row=True - ), - ) - - -@pytest.mark.asyncio -async def test_mid_cell_labels_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), - ) - - -async def _coro_wrapper(stream): - return stream - - -async def _process_chunks(*chunks): - async def _row_stream(): - yield ReadRowsResponse(chunks=chunks) - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = None - chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_row_stream()) - ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) - results = [] - async for row in merger: - results.append(row) - return results From 11935301b32872f264bab5b9afd2cd9ee654b4e2 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 6 Dec 2024 11:23:34 -0600 Subject: [PATCH 2/5] chore: use more verbose paths in hello snippets (#1045) --- samples/beam/requirements-test.txt | 2 +- samples/hello/async_main.py | 14 +++++++------- samples/hello/main.py | 8 ++++++-- samples/hello/requirements-test.txt | 2 +- samples/hello_happybase/requirements-test.txt | 2 +- samples/instanceadmin/requirements-test.txt | 2 +- samples/metricscaler/requirements-test.txt | 2 +- samples/quickstart/requirements-test.txt | 2 +- samples/quickstart_happybase/requirements-test.txt | 2 +- samples/snippets/data_client/requirements-test.txt | 2 +- samples/snippets/deletes/deletes_async_test.py | 8 ++++++++ samples/snippets/deletes/requirements-test.txt | 2 +- .../snippets/filters/filter_snippets_async_test.py | 8 ++++++++ samples/snippets/filters/requirements-test.txt | 2 +- samples/snippets/reads/requirements-test.txt | 2 +- samples/snippets/writes/requirements-test.txt | 2 +- samples/tableadmin/requirements-test.txt | 2 +- 17 files changed, 42 insertions(+), 22 deletions(-) diff --git a/samples/beam/requirements-test.txt b/samples/beam/requirements-test.txt index fe93bd52f..e079f8a60 100644 --- a/samples/beam/requirements-test.txt +++ b/samples/beam/requirements-test.txt @@ -1 +1 @@ -pytest==8.3.2 +pytest diff --git a/samples/hello/async_main.py b/samples/hello/async_main.py index 0130161e1..34159bedb 100644 --- a/samples/hello/async_main.py +++ b/samples/hello/async_main.py @@ -31,11 +31,11 @@ # [START bigtable_async_hw_imports] from google.cloud import bigtable from google.cloud.bigtable.data import row_filters -from google.cloud.bigtable.data import RowMutationEntry -from google.cloud.bigtable.data import SetCell -from google.cloud.bigtable.data import ReadRowsQuery # [END bigtable_async_hw_imports] +# use to ignore warnings +row_filters + async def main(project_id, instance_id, table_id): # [START bigtable_async_hw_connect] @@ -85,8 +85,8 @@ async def main(project_id, instance_id, table_id): # # https://cloud.google.com/bigtable/docs/schema-design row_key = "greeting{}".format(i).encode() - row_mutation = RowMutationEntry( - row_key, SetCell(column_family_id, column, value) + row_mutation = bigtable.data.RowMutationEntry( + row_key, bigtable.data.SetCell(column_family_id, column, value) ) mutations.append(row_mutation) await table.bulk_mutate_rows(mutations) @@ -95,7 +95,7 @@ async def main(project_id, instance_id, table_id): # [START 
bigtable_async_hw_create_filter] # Create a filter to only retrieve the most recent version of the cell # for each column across entire row. - row_filter = row_filters.CellsColumnLimitFilter(1) + row_filter = bigtable.data.row_filters.CellsColumnLimitFilter(1) # [END bigtable_async_hw_create_filter] # [START bigtable_async_hw_get_with_filter] @@ -112,7 +112,7 @@ async def main(project_id, instance_id, table_id): # [START bigtable_async_hw_scan_with_filter] # [START bigtable_async_hw_scan_all] print("Scanning for all greetings:") - query = ReadRowsQuery(row_filter=row_filter) + query = bigtable.data.ReadRowsQuery(row_filter=row_filter) async for row in await table.read_rows_stream(query): cell = row.cells[0] print(cell.value.decode("utf-8")) diff --git a/samples/hello/main.py b/samples/hello/main.py index 3e5078608..41124e826 100644 --- a/samples/hello/main.py +++ b/samples/hello/main.py @@ -36,6 +36,10 @@ # [END bigtable_hw_imports] +# use to avoid warnings +row_filters +column_family + def main(project_id, instance_id, table_id): # [START bigtable_hw_connect] @@ -52,7 +56,7 @@ def main(project_id, instance_id, table_id): print("Creating column family cf1 with Max Version GC rule...") # Create a column family with GC policy : most recent N versions # Define the GC policy to retain only the most recent 2 versions - max_versions_rule = column_family.MaxVersionsGCRule(2) + max_versions_rule = bigtable.column_family.MaxVersionsGCRule(2) column_family_id = "cf1" column_families = {column_family_id: max_versions_rule} if not table.exists(): @@ -93,7 +97,7 @@ def main(project_id, instance_id, table_id): # [START bigtable_hw_create_filter] # Create a filter to only retrieve the most recent version of the cell # for each column across entire row. - row_filter = row_filters.CellsColumnLimitFilter(1) + row_filter = bigtable.row_filters.CellsColumnLimitFilter(1) # [END bigtable_hw_create_filter] # [START bigtable_hw_get_with_filter] diff --git a/samples/hello/requirements-test.txt b/samples/hello/requirements-test.txt index fe93bd52f..e079f8a60 100644 --- a/samples/hello/requirements-test.txt +++ b/samples/hello/requirements-test.txt @@ -1 +1 @@ -pytest==8.3.2 +pytest diff --git a/samples/hello_happybase/requirements-test.txt b/samples/hello_happybase/requirements-test.txt index fe93bd52f..e079f8a60 100644 --- a/samples/hello_happybase/requirements-test.txt +++ b/samples/hello_happybase/requirements-test.txt @@ -1 +1 @@ -pytest==8.3.2 +pytest diff --git a/samples/instanceadmin/requirements-test.txt b/samples/instanceadmin/requirements-test.txt index fe93bd52f..e079f8a60 100644 --- a/samples/instanceadmin/requirements-test.txt +++ b/samples/instanceadmin/requirements-test.txt @@ -1 +1 @@ -pytest==8.3.2 +pytest diff --git a/samples/metricscaler/requirements-test.txt b/samples/metricscaler/requirements-test.txt index caf5f029c..13d734378 100644 --- a/samples/metricscaler/requirements-test.txt +++ b/samples/metricscaler/requirements-test.txt @@ -1,3 +1,3 @@ -pytest==8.3.2 +pytest mock==5.1.0 google-cloud-testutils diff --git a/samples/quickstart/requirements-test.txt b/samples/quickstart/requirements-test.txt index a63626120..ee4ba0186 100644 --- a/samples/quickstart/requirements-test.txt +++ b/samples/quickstart/requirements-test.txt @@ -1,2 +1,2 @@ -pytest==8.3.2 +pytest pytest-asyncio diff --git a/samples/quickstart_happybase/requirements-test.txt b/samples/quickstart_happybase/requirements-test.txt index fe93bd52f..55b033e90 100644 --- a/samples/quickstart_happybase/requirements-test.txt +++ 
b/samples/quickstart_happybase/requirements-test.txt @@ -1 +1 @@ -pytest==8.3.2 +pytest \ No newline at end of file diff --git a/samples/snippets/data_client/requirements-test.txt b/samples/snippets/data_client/requirements-test.txt index a63626120..ee4ba0186 100644 --- a/samples/snippets/data_client/requirements-test.txt +++ b/samples/snippets/data_client/requirements-test.txt @@ -1,2 +1,2 @@ -pytest==8.3.2 +pytest pytest-asyncio diff --git a/samples/snippets/deletes/deletes_async_test.py b/samples/snippets/deletes/deletes_async_test.py index 9408a8320..4fb4898e5 100644 --- a/samples/snippets/deletes/deletes_async_test.py +++ b/samples/snippets/deletes/deletes_async_test.py @@ -30,6 +30,14 @@ TABLE_ID = f"mobile-time-series-deletes-async-{str(uuid.uuid4())[:16]}" +@pytest.fixture(scope="module") +def event_loop(): + import asyncio + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + @pytest_asyncio.fixture(scope="module", autouse=True) async def table_id() -> AsyncGenerator[str, None]: with create_table_cm(PROJECT, BIGTABLE_INSTANCE, TABLE_ID, {"stats_summary": None, "cell_plan": None}, verbose=False): diff --git a/samples/snippets/deletes/requirements-test.txt b/samples/snippets/deletes/requirements-test.txt index a63626120..ee4ba0186 100644 --- a/samples/snippets/deletes/requirements-test.txt +++ b/samples/snippets/deletes/requirements-test.txt @@ -1,2 +1,2 @@ -pytest==8.3.2 +pytest pytest-asyncio diff --git a/samples/snippets/filters/filter_snippets_async_test.py b/samples/snippets/filters/filter_snippets_async_test.py index 124db8157..a3f83a6f2 100644 --- a/samples/snippets/filters/filter_snippets_async_test.py +++ b/samples/snippets/filters/filter_snippets_async_test.py @@ -34,6 +34,14 @@ TABLE_ID = f"mobile-time-series-filters-async-{str(uuid.uuid4())[:16]}" +@pytest.fixture(scope="module") +def event_loop(): + import asyncio + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + @pytest_asyncio.fixture(scope="module", autouse=True) async def table_id() -> AsyncGenerator[str, None]: with create_table_cm(PROJECT, BIGTABLE_INSTANCE, TABLE_ID, {"stats_summary": None, "cell_plan": None}): diff --git a/samples/snippets/filters/requirements-test.txt b/samples/snippets/filters/requirements-test.txt index a63626120..ee4ba0186 100644 --- a/samples/snippets/filters/requirements-test.txt +++ b/samples/snippets/filters/requirements-test.txt @@ -1,2 +1,2 @@ -pytest==8.3.2 +pytest pytest-asyncio diff --git a/samples/snippets/reads/requirements-test.txt b/samples/snippets/reads/requirements-test.txt index fe93bd52f..e079f8a60 100644 --- a/samples/snippets/reads/requirements-test.txt +++ b/samples/snippets/reads/requirements-test.txt @@ -1 +1 @@ -pytest==8.3.2 +pytest diff --git a/samples/snippets/writes/requirements-test.txt b/samples/snippets/writes/requirements-test.txt index 0f4b18778..5e15eb26f 100644 --- a/samples/snippets/writes/requirements-test.txt +++ b/samples/snippets/writes/requirements-test.txt @@ -1,2 +1,2 @@ backoff==2.2.1 -pytest==8.3.2 +pytest diff --git a/samples/tableadmin/requirements-test.txt b/samples/tableadmin/requirements-test.txt index 7f86b7bc4..a4c9e9c0b 100644 --- a/samples/tableadmin/requirements-test.txt +++ b/samples/tableadmin/requirements-test.txt @@ -1,2 +1,2 @@ -pytest==8.3.2 +pytest google-cloud-testutils==1.4.0 From f974823bf8a74c2f8b1bc69997b13bc1acaf8bef Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Dec 2024 15:41:19 -0600 Subject: [PATCH 3/5] feat: add generated sync client 
(#1017) --- .cross_sync/README.md | 2 +- docs/async_data_client/async_data_usage.rst | 18 - .../async_data_client.rst | 2 +- .../async_data_execute_query_iterator.rst | 0 .../async_data_mutations_batcher.rst | 0 .../async_data_table.rst | 0 .../common_data_exceptions.rst} | 0 .../common_data_execute_query_metadata.rst} | 0 .../common_data_execute_query_values.rst} | 0 .../common_data_mutations.rst} | 0 .../common_data_read_modify_write_rules.rst} | 0 .../common_data_read_rows_query.rst} | 0 .../common_data_row.rst} | 0 .../common_data_row_filters.rst} | 0 docs/data_client/data_client_usage.rst | 39 + docs/data_client/sync_data_client.rst | 6 + .../sync_data_execute_query_iterator.rst | 6 + .../sync_data_mutations_batcher.rst | 6 + docs/data_client/sync_data_table.rst | 6 + docs/index.rst | 4 +- docs/scripts/patch_devsite_toc.py | 9 +- google/cloud/bigtable/data/__init__.py | 18 +- .../bigtable/data/_async/_mutate_rows.py | 2 +- .../cloud/bigtable/data/_async/_read_rows.py | 2 +- google/cloud/bigtable/data/_async/client.py | 11 +- .../bigtable/data/_async/mutations_batcher.py | 8 +- google/cloud/bigtable/data/_helpers.py | 5 +- .../data/_sync_autogen/_mutate_rows.py | 182 ++ .../bigtable/data/_sync_autogen/_read_rows.py | 304 ++ .../bigtable/data/_sync_autogen/client.py | 1234 +++++++ .../data/_sync_autogen/mutations_batcher.py | 449 +++ .../bigtable/data/execute_query/__init__.py | 5 + .../_async/execute_query_iterator.py | 2 + .../_sync_autogen/execute_query_iterator.py | 186 ++ noxfile.py | 16 +- .../handlers/client_handler_data_async.py | 2 +- tests/system/data/test_system_async.py | 2 +- tests/system/data/test_system_autogen.py | 828 +++++ tests/unit/data/_async/test__mutate_rows.py | 2 +- tests/unit/data/_async/test__read_rows.py | 1 + tests/unit/data/_async/test_client.py | 18 +- .../data/_async/test_mutations_batcher.py | 2 +- .../data/_async/test_read_rows_acceptance.py | 2 +- tests/unit/data/_sync_autogen/__init__.py | 0 .../data/_sync_autogen/test__mutate_rows.py | 307 ++ .../data/_sync_autogen/test__read_rows.py | 354 ++ tests/unit/data/_sync_autogen/test_client.py | 2889 +++++++++++++++++ .../_sync_autogen/test_mutations_batcher.py | 1078 ++++++ .../test_read_rows_acceptance.py | 328 ++ .../_async/test_query_iterator.py | 1 - .../execute_query/_sync_autogen/__init__.py | 0 .../_sync_autogen/test_query_iterator.py | 163 + .../test_execute_query_parameters_parsing.py | 2 +- tests/unit/data/test__helpers.py | 2 +- 54 files changed, 8448 insertions(+), 55 deletions(-) delete mode 100644 docs/async_data_client/async_data_usage.rst rename docs/{async_data_client => data_client}/async_data_client.rst (79%) rename docs/{async_data_client => data_client}/async_data_execute_query_iterator.rst (100%) rename docs/{async_data_client => data_client}/async_data_mutations_batcher.rst (100%) rename docs/{async_data_client => data_client}/async_data_table.rst (100%) rename docs/{async_data_client/async_data_exceptions.rst => data_client/common_data_exceptions.rst} (100%) rename docs/{async_data_client/async_data_execute_query_metadata.rst => data_client/common_data_execute_query_metadata.rst} (100%) rename docs/{async_data_client/async_data_execute_query_values.rst => data_client/common_data_execute_query_values.rst} (100%) rename docs/{async_data_client/async_data_mutations.rst => data_client/common_data_mutations.rst} (100%) rename docs/{async_data_client/async_data_read_modify_write_rules.rst => data_client/common_data_read_modify_write_rules.rst} (100%) rename 
docs/{async_data_client/async_data_read_rows_query.rst => data_client/common_data_read_rows_query.rst} (100%) rename docs/{async_data_client/async_data_row.rst => data_client/common_data_row.rst} (100%) rename docs/{async_data_client/async_data_row_filters.rst => data_client/common_data_row_filters.rst} (100%) create mode 100644 docs/data_client/data_client_usage.rst create mode 100644 docs/data_client/sync_data_client.rst create mode 100644 docs/data_client/sync_data_execute_query_iterator.rst create mode 100644 docs/data_client/sync_data_mutations_batcher.rst create mode 100644 docs/data_client/sync_data_table.rst create mode 100644 google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py create mode 100644 google/cloud/bigtable/data/_sync_autogen/_read_rows.py create mode 100644 google/cloud/bigtable/data/_sync_autogen/client.py create mode 100644 google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py create mode 100644 google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py create mode 100644 tests/system/data/test_system_autogen.py create mode 100644 tests/unit/data/_sync_autogen/__init__.py create mode 100644 tests/unit/data/_sync_autogen/test__mutate_rows.py create mode 100644 tests/unit/data/_sync_autogen/test__read_rows.py create mode 100644 tests/unit/data/_sync_autogen/test_client.py create mode 100644 tests/unit/data/_sync_autogen/test_mutations_batcher.py create mode 100644 tests/unit/data/_sync_autogen/test_read_rows_acceptance.py create mode 100644 tests/unit/data/execute_query/_sync_autogen/__init__.py create mode 100644 tests/unit/data/execute_query/_sync_autogen/test_query_iterator.py diff --git a/.cross_sync/README.md b/.cross_sync/README.md index 4214e0d78..18a9aafdf 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -62,7 +62,7 @@ CrossSync provides a set of annotations to mark up async classes, to guide the g ### Code Generation -Generation can be initiated using `python .cross_sync/generate.py .` +Generation can be initiated using `nox -s generate_sync` from the root of the project. This will find all classes with the `__CROSS_SYNC_OUTPUT__ = "path/to/output"` annotation, and generate a sync version of classes marked with `@CrossSync.convert_sync` at the output path. diff --git a/docs/async_data_client/async_data_usage.rst b/docs/async_data_client/async_data_usage.rst deleted file mode 100644 index 61d5837fd..000000000 --- a/docs/async_data_client/async_data_usage.rst +++ /dev/null @@ -1,18 +0,0 @@ -Async Data Client -================= - -.. toctree:: - :maxdepth: 2 - - async_data_client - async_data_table - async_data_mutations_batcher - async_data_read_rows_query - async_data_row - async_data_row_filters - async_data_mutations - async_data_read_modify_write_rules - async_data_exceptions - async_data_execute_query_iterator - async_data_execute_query_values - async_data_execute_query_metadata diff --git a/docs/async_data_client/async_data_client.rst b/docs/data_client/async_data_client.rst similarity index 79% rename from docs/async_data_client/async_data_client.rst rename to docs/data_client/async_data_client.rst index 0e1d9e23e..2ddcc090c 100644 --- a/docs/async_data_client/async_data_client.rst +++ b/docs/data_client/async_data_client.rst @@ -7,6 +7,6 @@ Bigtable Data Client Async performance benefits, the codebase should be designed to be async from the ground up. -.. autoclass:: google.cloud.bigtable.data._async.client.BigtableDataClientAsync +.. 
autoclass:: google.cloud.bigtable.data.BigtableDataClientAsync :members: :show-inheritance: diff --git a/docs/async_data_client/async_data_execute_query_iterator.rst b/docs/data_client/async_data_execute_query_iterator.rst similarity index 100% rename from docs/async_data_client/async_data_execute_query_iterator.rst rename to docs/data_client/async_data_execute_query_iterator.rst diff --git a/docs/async_data_client/async_data_mutations_batcher.rst b/docs/data_client/async_data_mutations_batcher.rst similarity index 100% rename from docs/async_data_client/async_data_mutations_batcher.rst rename to docs/data_client/async_data_mutations_batcher.rst diff --git a/docs/async_data_client/async_data_table.rst b/docs/data_client/async_data_table.rst similarity index 100% rename from docs/async_data_client/async_data_table.rst rename to docs/data_client/async_data_table.rst diff --git a/docs/async_data_client/async_data_exceptions.rst b/docs/data_client/common_data_exceptions.rst similarity index 100% rename from docs/async_data_client/async_data_exceptions.rst rename to docs/data_client/common_data_exceptions.rst diff --git a/docs/async_data_client/async_data_execute_query_metadata.rst b/docs/data_client/common_data_execute_query_metadata.rst similarity index 100% rename from docs/async_data_client/async_data_execute_query_metadata.rst rename to docs/data_client/common_data_execute_query_metadata.rst diff --git a/docs/async_data_client/async_data_execute_query_values.rst b/docs/data_client/common_data_execute_query_values.rst similarity index 100% rename from docs/async_data_client/async_data_execute_query_values.rst rename to docs/data_client/common_data_execute_query_values.rst diff --git a/docs/async_data_client/async_data_mutations.rst b/docs/data_client/common_data_mutations.rst similarity index 100% rename from docs/async_data_client/async_data_mutations.rst rename to docs/data_client/common_data_mutations.rst diff --git a/docs/async_data_client/async_data_read_modify_write_rules.rst b/docs/data_client/common_data_read_modify_write_rules.rst similarity index 100% rename from docs/async_data_client/async_data_read_modify_write_rules.rst rename to docs/data_client/common_data_read_modify_write_rules.rst diff --git a/docs/async_data_client/async_data_read_rows_query.rst b/docs/data_client/common_data_read_rows_query.rst similarity index 100% rename from docs/async_data_client/async_data_read_rows_query.rst rename to docs/data_client/common_data_read_rows_query.rst diff --git a/docs/async_data_client/async_data_row.rst b/docs/data_client/common_data_row.rst similarity index 100% rename from docs/async_data_client/async_data_row.rst rename to docs/data_client/common_data_row.rst diff --git a/docs/async_data_client/async_data_row_filters.rst b/docs/data_client/common_data_row_filters.rst similarity index 100% rename from docs/async_data_client/async_data_row_filters.rst rename to docs/data_client/common_data_row_filters.rst diff --git a/docs/data_client/data_client_usage.rst b/docs/data_client/data_client_usage.rst new file mode 100644 index 000000000..f5bbac278 --- /dev/null +++ b/docs/data_client/data_client_usage.rst @@ -0,0 +1,39 @@ +Data Client +=========== + +Sync Surface +------------ + +.. toctree:: + :maxdepth: 3 + + sync_data_client + sync_data_table + sync_data_mutations_batcher + sync_data_execute_query_iterator + +Async Surface +------------- + +.. 
toctree:: + :maxdepth: 3 + + async_data_client + async_data_table + async_data_mutations_batcher + async_data_execute_query_iterator + +Common Classes +-------------- + +.. toctree:: + :maxdepth: 3 + + common_data_read_rows_query + common_data_row + common_data_row_filters + common_data_mutations + common_data_read_modify_write_rules + common_data_exceptions + common_data_execute_query_values + common_data_execute_query_metadata diff --git a/docs/data_client/sync_data_client.rst b/docs/data_client/sync_data_client.rst new file mode 100644 index 000000000..cf7c00dad --- /dev/null +++ b/docs/data_client/sync_data_client.rst @@ -0,0 +1,6 @@ +Bigtable Data Client +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: google.cloud.bigtable.data.BigtableDataClient + :members: + :show-inheritance: diff --git a/docs/data_client/sync_data_execute_query_iterator.rst b/docs/data_client/sync_data_execute_query_iterator.rst new file mode 100644 index 000000000..6eb9f84db --- /dev/null +++ b/docs/data_client/sync_data_execute_query_iterator.rst @@ -0,0 +1,6 @@ +Execute Query Iterator +~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: google.cloud.bigtable.data.execute_query.ExecuteQueryIterator + :members: + :show-inheritance: diff --git a/docs/data_client/sync_data_mutations_batcher.rst b/docs/data_client/sync_data_mutations_batcher.rst new file mode 100644 index 000000000..2b7d1bfe0 --- /dev/null +++ b/docs/data_client/sync_data_mutations_batcher.rst @@ -0,0 +1,6 @@ +Mutations Batcher +~~~~~~~~~~~~~~~~~ + +.. automodule:: google.cloud.bigtable.data._sync_autogen.mutations_batcher + :members: + :show-inheritance: diff --git a/docs/data_client/sync_data_table.rst b/docs/data_client/sync_data_table.rst new file mode 100644 index 000000000..95c91eb27 --- /dev/null +++ b/docs/data_client/sync_data_table.rst @@ -0,0 +1,6 @@ +Table +~~~~~ + +.. autoclass:: google.cloud.bigtable.data.Table + :members: + :show-inheritance: diff --git a/docs/index.rst b/docs/index.rst index 4204e981d..c7f9721f3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,10 +5,10 @@ Client Types ------------- .. 
toctree:: - :maxdepth: 2 + :maxdepth: 3 + data_client/data_client_usage classic_client/usage - async_data_client/async_data_usage Changelog diff --git a/docs/scripts/patch_devsite_toc.py b/docs/scripts/patch_devsite_toc.py index 456d0af7b..5889300d2 100644 --- a/docs/scripts/patch_devsite_toc.py +++ b/docs/scripts/patch_devsite_toc.py @@ -117,7 +117,8 @@ def __init__(self, dir_name, index_file_name): continue # bail when toc indented block is done if not line.startswith(" ") and not line.startswith("\t"): - break + in_toc = False + continue # extract entries self.items.append(self.extract_toc_entry(line.strip())) @@ -194,9 +195,7 @@ def validate_toc(toc_file_path, expected_section_list, added_sections): # Add secrtions for the async_data_client and classic_client directories toc_path = "_build/html/docfx_yaml/toc.yml" custom_sections = [ - TocSection( - dir_name="async_data_client", index_file_name="async_data_usage.rst" - ), + TocSection(dir_name="data_client", index_file_name="data_client_usage.rst"), TocSection(dir_name="classic_client", index_file_name="usage.rst"), ] add_sections(toc_path, custom_sections) @@ -210,7 +209,7 @@ def validate_toc(toc_file_path, expected_section_list, added_sections): "bigtable APIs", "Changelog", "Multiprocessing", - "Async Data Client", + "Data Client", "Classic Client", ], added_sections=custom_sections, diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 43ea69fdf..15f9bc167 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -17,8 +17,10 @@ from google.cloud.bigtable.data._async.client import BigtableDataClientAsync from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync +from google.cloud.bigtable.data._sync_autogen.client import BigtableDataClient +from google.cloud.bigtable.data._sync_autogen.client import Table +from google.cloud.bigtable.data._sync_autogen.mutations_batcher import MutationsBatcher from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.read_rows_query import RowRange @@ -52,13 +54,22 @@ from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable_v2.services.bigtable.client import ( + BigtableClient, +) +from google.cloud.bigtable.data._sync_autogen._read_rows import _ReadRowsOperation +from google.cloud.bigtable.data._sync_autogen._mutate_rows import _MutateRowsOperation + from google.cloud.bigtable.data._cross_sync import CrossSync CrossSync.add_mapping("GapicClient", BigtableAsyncClient) +CrossSync._Sync_Impl.add_mapping("GapicClient", BigtableClient) CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) +CrossSync._Sync_Impl.add_mapping("_ReadRowsOperation", _ReadRowsOperation) CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) +CrossSync._Sync_Impl.add_mapping("_MutateRowsOperation", _MutateRowsOperation) CrossSync.add_mapping("MutationsBatcher", MutationsBatcherAsync) - +CrossSync._Sync_Impl.add_mapping("MutationsBatcher", MutationsBatcher) __version__: str = package_version.__version__ @@ -66,6 +77,9 @@ "BigtableDataClientAsync", "TableAsync", "MutationsBatcherAsync", + "BigtableDataClient", + "Table", + "MutationsBatcher", "RowKeySamples", "ReadRowsQuery", "RowRange", diff --git 
a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index c5795c464..bf618bf04 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index c02b3750d..6d2fa3a7d 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index d560d7e1e..c7cc0de6b 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -85,8 +85,10 @@ ) from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE else: + from typing import Iterable # noqa: F401 from grpc import insecure_channel from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore + from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE if TYPE_CHECKING: @@ -100,6 +102,13 @@ from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( ExecuteQueryIteratorAsync, ) + else: + from google.cloud.bigtable.data._sync_autogen.mutations_batcher import ( # noqa: F401 + MutationsBatcher, + ) + from google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator import ( # noqa: F401 + ExecuteQueryIterator, + ) __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.client" diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 65070c880..6e15bb5f3 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,7 +14,7 @@ # from __future__ import annotations -from typing import Sequence, TYPE_CHECKING +from typing import Sequence, TYPE_CHECKING, cast import atexit import warnings from collections import deque @@ -250,7 +250,7 @@ def __init__( ) # used by sync class to manage flush_internal tasks self._sync_flush_executor = ( - concurrent.futures.ThreadPoolExecutor(max_workers=1) + concurrent.futures.ThreadPoolExecutor(max_workers=4) if not CrossSync.is_async else None ) @@ -305,7 +305,7 @@ async def append(self, mutation_entry: RowMutationEntry): # TODO: return a future to track completion of this entry if self._closed.is_set(): raise RuntimeError("Cannot append to closed MutationsBatcher") - if isinstance(mutation_entry, Mutation): # type: ignore + if isinstance(cast(Mutation, mutation_entry), Mutation): raise ValueError( f"invalid mutation type: {type(mutation_entry).__name__}. Only RowMutationEntry objects are supported by batcher" ) diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index bd1c09d52..4c45e5c1c 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -29,6 +29,7 @@ if TYPE_CHECKING: import grpc from google.cloud.bigtable.data import TableAsync + from google.cloud.bigtable.data import Table """ Helper functions used in various places in the library. @@ -120,7 +121,7 @@ def _retry_exception_factory( def _get_timeouts( operation: float | TABLE_DEFAULT, attempt: float | None | TABLE_DEFAULT, - table: "TableAsync", + table: "TableAsync" | "Table", ) -> tuple[float, float]: """ Convert passed in timeout values to floats, using table defaults if necessary. @@ -207,7 +208,7 @@ def _get_error_type( def _get_retryable_errors( call_codes: Sequence["grpc.StatusCode" | int | type[Exception]] | TABLE_DEFAULT, - table: "TableAsync", + table: "TableAsync" | "Table", ) -> list[type[Exception]]: """ Convert passed in retryable error codes to a list of exception types. diff --git a/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py b/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py new file mode 100644 index 000000000..8e8c5ca89 --- /dev/null +++ b/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py @@ -0,0 +1,182 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This file is automatically generated by CrossSync. Do not edit manually. 
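For orientation (not part of the patch): a minimal sketch of how the batcher surface touched above is meant to be used from the synchronous client added in this change. The project, instance, and table IDs are placeholders, and a reachable Bigtable instance or emulator is assumed.

    from google.cloud.bigtable.data import BigtableDataClient, RowMutationEntry, SetCell

    # Placeholder identifiers; assumes a reachable Bigtable instance or emulator.
    with BigtableDataClient(project="my-project") as client:
        table = client.get_table("my-instance", "my-table")
        with table.mutations_batcher() as batcher:
            # Only RowMutationEntry objects are accepted; passing a bare Mutation
            # such as SetCell directly raises ValueError (see the check above).
            batcher.append(
                RowMutationEntry(b"row-1", [SetCell("family", b"qualifier", b"value")])
            )
        # Leaving the context manager flushes any pending mutations.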
+ +from __future__ import annotations +from typing import Sequence, TYPE_CHECKING +import functools +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +import google.cloud.bigtable.data.exceptions as bt_exceptions +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data.mutations import _EntryWithProto +from google.cloud.bigtable.data._cross_sync import CrossSync + +if TYPE_CHECKING: + from google.cloud.bigtable.data.mutations import RowMutationEntry + from google.cloud.bigtable_v2.services.bigtable.client import ( + BigtableClient as GapicClientType, + ) + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType + + +class _MutateRowsOperation: + """ + MutateRowsOperation manages the logic of sending a set of row mutations, + and retrying on failed entries. It manages this using the _run_attempt + function, which attempts to mutate all outstanding entries, and raises + _MutateRowsIncomplete if any retryable errors are encountered. + + Errors are exposed as a MutationsExceptionGroup, which contains a list of + exceptions organized by the related failed mutation entries. + + Args: + gapic_client: the client to use for the mutate_rows call + table: the table associated with the request + mutation_entries: a list of RowMutationEntry objects to send to the server + operation_timeout: the timeout to use for the entire operation, in seconds. + attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. + If not specified, the request will run until operation_timeout is reached. + """ + + def __init__( + self, + gapic_client: GapicClientType, + table: TableType, + mutation_entries: list["RowMutationEntry"], + operation_timeout: float, + attempt_timeout: float | None, + retryable_exceptions: Sequence[type[Exception]] = (), + ): + total_mutations = sum((len(entry.mutations) for entry in mutation_entries)) + if total_mutations > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT: + raise ValueError( + f"mutate_rows requests can contain at most {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across all entries. Found {total_mutations}." 
+ ) + self._gapic_fn = functools.partial( + gapic_client.mutate_rows, + table_name=table.table_name, + app_profile_id=table.app_profile_id, + retry=None, + ) + self.is_retryable = retries.if_exception_type( + *retryable_exceptions, bt_exceptions._MutateRowsIncomplete + ) + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + self._operation = lambda: CrossSync._Sync_Impl.retry_target( + self._run_attempt, + self.is_retryable, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) + self.timeout_generator = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries] + self.remaining_indices = list(range(len(self.mutations))) + self.errors: dict[int, list[Exception]] = {} + + def start(self): + """Start the operation, and run until completion + + Raises: + MutationsExceptionGroup: if any mutations failed""" + try: + self._operation() + except Exception as exc: + incomplete_indices = self.remaining_indices.copy() + for idx in incomplete_indices: + self._handle_entry_error(idx, exc) + finally: + all_errors: list[Exception] = [] + for idx, exc_list in self.errors.items(): + if len(exc_list) == 0: + raise core_exceptions.ClientError( + f"Mutation {idx} failed with no associated errors" + ) + elif len(exc_list) == 1: + cause_exc = exc_list[0] + else: + cause_exc = bt_exceptions.RetryExceptionGroup(exc_list) + entry = self.mutations[idx].entry + all_errors.append( + bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) + ) + if all_errors: + raise bt_exceptions.MutationsExceptionGroup( + all_errors, len(self.mutations) + ) + + def _run_attempt(self): + """Run a single attempt of the mutate_rows rpc. + + Raises: + _MutateRowsIncomplete: if there are failed mutations eligible for + retry after the attempt is complete + GoogleAPICallError: if the gapic rpc fails""" + request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] + active_request_indices = { + req_idx: orig_idx + for (req_idx, orig_idx) in enumerate(self.remaining_indices) + } + self.remaining_indices = [] + if not request_entries: + return + try: + result_generator = self._gapic_fn( + timeout=next(self.timeout_generator), + entries=request_entries, + retry=None, + ) + for result_list in result_generator: + for result in result_list.entries: + orig_idx = active_request_indices[result.index] + entry_error = core_exceptions.from_grpc_status( + result.status.code, + result.status.message, + details=result.status.details, + ) + if result.status.code != 0: + self._handle_entry_error(orig_idx, entry_error) + elif orig_idx in self.errors: + del self.errors[orig_idx] + del active_request_indices[result.index] + except Exception as exc: + for idx in active_request_indices.values(): + self._handle_entry_error(idx, exc) + raise + if self.remaining_indices: + raise bt_exceptions._MutateRowsIncomplete + + def _handle_entry_error(self, idx: int, exc: Exception): + """Add an exception to the list of exceptions for a given mutation index, + and add the index to the list of remaining indices if the exception is + retryable. 
+ + Args: + idx: the index of the mutation that failed + exc: the exception to add to the list""" + entry = self.mutations[idx].entry + self.errors.setdefault(idx, []).append(exc) + if ( + entry.is_idempotent() + and self.is_retryable(exc) + and (idx not in self.remaining_indices) + ): + self.remaining_indices.append(idx) diff --git a/google/cloud/bigtable/data/_sync_autogen/_read_rows.py b/google/cloud/bigtable/data/_sync_autogen/_read_rows.py new file mode 100644 index 000000000..92619c6a4 --- /dev/null +++ b/google/cloud/bigtable/data/_sync_autogen/_read_rows.py @@ -0,0 +1,304 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +# This file is automatically generated by CrossSync. Do not edit manually. + +from __future__ import annotations +from typing import Sequence, TYPE_CHECKING +from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB +from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB +from google.cloud.bigtable_v2.types import RowSet as RowSetPB +from google.cloud.bigtable_v2.types import RowRange as RowRangePB +from google.cloud.bigtable.data.row import Row, Cell +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import _RowSetComplete +from google.cloud.bigtable.data.exceptions import _ResetRow +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.api_core import retry as retries +from google.api_core.retry import exponential_sleep_generator +from google.cloud.bigtable.data._cross_sync import CrossSync + +if TYPE_CHECKING: + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType + + +class _ReadRowsOperation: + """ + ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream + into a stream of Row objects. + + ReadRowsOperation.merge_row_response_stream takes in a stream of ReadRowsResponse + and turns them into a stream of Row objects using an internal + StateMachine. + + ReadRowsOperation(request, client) handles row merging logic end-to-end, including + performing retries on stream errors. 
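Failures from the operation above surface as a MutationsExceptionGroup whose sub-exceptions are FailedMutationEntryError objects. A hedged sketch of inspecting them follows; bulk_mutate_rows is assumed to exist on the sync Table surface (mirroring the async client), and `table`/`entries` are supplied by the caller.

    from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup

    def apply_entries(table, entries):
        # `table` is assumed to be a sync Table and `entries` a list of RowMutationEntry.
        try:
            table.bulk_mutate_rows(entries)
        except MutationsExceptionGroup as group:
            # Each sub-exception carries the failed entry's index, the entry itself,
            # and the underlying cause (a single error or a RetryExceptionGroup).
            for failed in group.exceptions:
                print(failed.index, failed.entry.row_key, repr(failed.__cause__))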
+ + Args: + query: The query to execute + table: The table to send the request to + operation_timeout: The total time to allow for the operation, in seconds + attempt_timeout: The time to allow for each individual attempt, in seconds + retryable_exceptions: A list of exceptions that should trigger a retry + """ + + __slots__ = ( + "attempt_timeout_gen", + "operation_timeout", + "request", + "table", + "_predicate", + "_last_yielded_row_key", + "_remaining_count", + ) + + def __init__( + self, + query: ReadRowsQuery, + table: TableType, + operation_timeout: float, + attempt_timeout: float, + retryable_exceptions: Sequence[type[Exception]] = (), + ): + self.attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self.operation_timeout = operation_timeout + if isinstance(query, dict): + self.request = ReadRowsRequestPB( + **query, + table_name=table.table_name, + app_profile_id=table.app_profile_id, + ) + else: + self.request = query._to_pb(table) + self.table = table + self._predicate = retries.if_exception_type(*retryable_exceptions) + self._last_yielded_row_key: bytes | None = None + self._remaining_count: int | None = self.request.rows_limit or None + + def start_operation(self) -> CrossSync._Sync_Impl.Iterable[Row]: + """Start the read_rows operation, retrying on retryable errors. + + Yields: + Row: The next row in the stream""" + return CrossSync._Sync_Impl.retry_target_stream( + self._read_rows_attempt, + self._predicate, + exponential_sleep_generator(0.01, 60, multiplier=2), + self.operation_timeout, + exception_factory=_retry_exception_factory, + ) + + def _read_rows_attempt(self) -> CrossSync._Sync_Impl.Iterable[Row]: + """Attempt a single read_rows rpc call. + This function is intended to be wrapped by retry logic, + which will call this function until it succeeds or + a non-retryable error is raised. 
+ + Yields: + Row: The next row in the stream""" + if self._last_yielded_row_key is not None: + try: + self.request.rows = self._revise_request_rowset( + row_set=self.request.rows, + last_seen_row_key=self._last_yielded_row_key, + ) + except _RowSetComplete: + return self.merge_rows(None) + if self._remaining_count is not None: + self.request.rows_limit = self._remaining_count + if self._remaining_count == 0: + return self.merge_rows(None) + gapic_stream = self.table.client._gapic_client.read_rows( + self.request, timeout=next(self.attempt_timeout_gen), retry=None + ) + chunked_stream = self.chunk_stream(gapic_stream) + return self.merge_rows(chunked_stream) + + def chunk_stream( + self, + stream: CrossSync._Sync_Impl.Awaitable[ + CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB] + ], + ) -> CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB.CellChunk]: + """process chunks out of raw read_rows stream + + Args: + stream: the raw read_rows stream from the gapic client + Yields: + ReadRowsResponsePB.CellChunk: the next chunk in the stream""" + for resp in stream: + resp = resp._pb + if resp.last_scanned_row_key: + if ( + self._last_yielded_row_key is not None + and resp.last_scanned_row_key <= self._last_yielded_row_key + ): + raise InvalidChunk("last scanned out of order") + self._last_yielded_row_key = resp.last_scanned_row_key + current_key = None + for c in resp.chunks: + if current_key is None: + current_key = c.row_key + if current_key is None: + raise InvalidChunk("first chunk is missing a row key") + elif ( + self._last_yielded_row_key + and current_key <= self._last_yielded_row_key + ): + raise InvalidChunk("row keys should be strictly increasing") + yield c + if c.reset_row: + current_key = None + elif c.commit_row: + self._last_yielded_row_key = current_key + if self._remaining_count is not None: + self._remaining_count -= 1 + if self._remaining_count < 0: + raise InvalidChunk("emit count exceeds row limit") + current_key = None + + @staticmethod + def merge_rows( + chunks: CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB.CellChunk] | None, + ) -> CrossSync._Sync_Impl.Iterable[Row]: + """Merge chunks into rows + + Args: + chunks: the chunk stream to merge + Yields: + Row: the next row in the stream""" + if chunks is None: + return + it = chunks.__iter__() + while True: + try: + c = it.__next__() + except CrossSync._Sync_Impl.StopIteration: + return + row_key = c.row_key + if not row_key: + raise InvalidChunk("first row chunk is missing key") + cells = [] + family: str | None = None + qualifier: bytes | None = None + try: + while True: + if c.reset_row: + raise _ResetRow(c) + k = c.row_key + f = c.family_name.value + q = c.qualifier.value if c.HasField("qualifier") else None + if k and k != row_key: + raise InvalidChunk("unexpected new row key") + if f: + family = f + if q is not None: + qualifier = q + else: + raise InvalidChunk("new family without qualifier") + elif family is None: + raise InvalidChunk("missing family") + elif q is not None: + if family is None: + raise InvalidChunk("new qualifier without family") + qualifier = q + elif qualifier is None: + raise InvalidChunk("missing qualifier") + ts = c.timestamp_micros + labels = c.labels if c.labels else [] + value = c.value + if c.value_size > 0: + buffer = [value] + while c.value_size > 0: + c = it.__next__() + t = c.timestamp_micros + cl = c.labels + k = c.row_key + if ( + c.HasField("family_name") + and c.family_name.value != family + ): + raise InvalidChunk("family changed mid cell") + if ( + c.HasField("qualifier") + and 
c.qualifier.value != qualifier + ): + raise InvalidChunk("qualifier changed mid cell") + if t and t != ts: + raise InvalidChunk("timestamp changed mid cell") + if cl and cl != labels: + raise InvalidChunk("labels changed mid cell") + if k and k != row_key: + raise InvalidChunk("row key changed mid cell") + if c.reset_row: + raise _ResetRow(c) + buffer.append(c.value) + value = b"".join(buffer) + cells.append( + Cell(value, row_key, family, qualifier, ts, list(labels)) + ) + if c.commit_row: + yield Row(row_key, cells) + break + c = it.__next__() + except _ResetRow as e: + c = e.chunk + if ( + c.row_key + or c.HasField("family_name") + or c.HasField("qualifier") + or c.timestamp_micros + or c.labels + or c.value + ): + raise InvalidChunk("reset row with data") + continue + except CrossSync._Sync_Impl.StopIteration: + raise InvalidChunk("premature end of stream") + + @staticmethod + def _revise_request_rowset(row_set: RowSetPB, last_seen_row_key: bytes) -> RowSetPB: + """Revise the rows in the request to avoid ones we've already processed. + + Args: + row_set: the row set from the request + last_seen_row_key: the last row key encountered + Returns: + RowSetPB: the new rowset after adjusting for the last seen key + Raises: + _RowSetComplete: if there are no rows left to process after the revision""" + if row_set is None or (not row_set.row_ranges and (not row_set.row_keys)): + last_seen = last_seen_row_key + return RowSetPB(row_ranges=[RowRangePB(start_key_open=last_seen)]) + adjusted_keys: list[bytes] = [ + k for k in row_set.row_keys if k > last_seen_row_key + ] + adjusted_ranges: list[RowRangePB] = [] + for row_range in row_set.row_ranges: + end_key = row_range.end_key_closed or row_range.end_key_open or None + if end_key is None or end_key > last_seen_row_key: + new_range = RowRangePB(row_range) + start_key = row_range.start_key_closed or row_range.start_key_open + if start_key is None or start_key <= last_seen_row_key: + new_range.start_key_open = last_seen_row_key + adjusted_ranges.append(new_range) + if len(adjusted_keys) == 0 and len(adjusted_ranges) == 0: + raise _RowSetComplete() + return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges) diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py new file mode 100644 index 000000000..37e192147 --- /dev/null +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -0,0 +1,1234 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +# This file is automatically generated by CrossSync. Do not edit manually.
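Since _revise_request_rowset above is a static method, its trimming behaviour can be sketched directly; the keys and range below are arbitrary examples.

    from google.cloud.bigtable_v2.types import RowSet, RowRange
    from google.cloud.bigtable.data._sync_autogen._read_rows import _ReadRowsOperation

    original = RowSet(
        row_keys=[b"a", b"m", b"z"],
        row_ranges=[RowRange(start_key_closed=b"a", end_key_closed=b"q")],
    )
    # Keys at or before the last-seen key are dropped; the surviving range is
    # reopened just after it (start_key_open=b"m", end_key_closed=b"q").
    revised = _ReadRowsOperation._revise_request_rowset(original, last_seen_row_key=b"m")
    print(list(revised.row_keys))  # expected: [b"z"]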
+ +from __future__ import annotations +from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING +import time +import warnings +import random +import os +import concurrent.futures +from functools import partial +from grpc import Channel +from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType +from google.cloud.bigtable.data.execute_query.metadata import SqlType +from google.cloud.bigtable.data.execute_query._parameters_formatting import ( + _format_execute_query_params, +) +from google.cloud.bigtable_v2.services.bigtable.transports.base import ( + DEFAULT_CLIENT_INFO, +) +from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest +from google.cloud.client import ClientWithProject +from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.api_core import retry as retries +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import ServiceUnavailable +from google.api_core.exceptions import Aborted +import google.auth.credentials +import google.auth._default +from google.api_core import client_options as client_options_lib +from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT +from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.exceptions import FailedQueryShardError +from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _WarmedInstanceKey +from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT +from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data._helpers import _validate_timeouts +from google.cloud.bigtable.data._helpers import _get_error_type +from google.cloud.bigtable.data._helpers import _get_retryable_errors +from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry +from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule +from google.cloud.bigtable.data.row_filters import RowFilter +from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter +from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter +from google.cloud.bigtable.data.row_filters import RowFilterChain +from google.cloud.bigtable.data._cross_sync import CrossSync +from typing import Iterable +from grpc import insecure_channel +from google.cloud.bigtable_v2.services.bigtable.transports import ( + BigtableGrpcTransport as TransportType, +) +from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE + +if TYPE_CHECKING: + from google.cloud.bigtable.data._helpers import RowKeySamples + from google.cloud.bigtable.data._helpers import ShardedQuery + from google.cloud.bigtable.data._sync_autogen.mutations_batcher import ( + MutationsBatcher, + ) + from google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator import ( + ExecuteQueryIterator, + ) + + +@CrossSync._Sync_Impl.add_mapping_decorator("DataClient") +class BigtableDataClient(ClientWithProject): + def __init__( + self, + *, + project: str | None = None, + credentials: google.auth.credentials.Credentials | None = None, + client_options: dict[str, Any] + | "google.api_core.client_options.ClientOptions" + | 
None = None, + **kwargs, + ): + """Create a client instance for the Bigtable Data API + + + + Args: + project: the project which the client acts on behalf of. + If not passed, falls back to the default inferred + from the environment. + credentials: + The OAuth2 Credentials to use for this + client. If not passed (and if no ``_http`` object is + passed), falls back to the default inferred from the + environment. + client_options: + Client options used to set user options + on the client. API Endpoint should be set through client_options. + Raises: + """ + if "pool_size" in kwargs: + warnings.warn("pool_size no longer supported") + client_info = DEFAULT_CLIENT_INFO + client_info.client_library_version = self._client_version() + if type(client_options) is dict: + client_options = client_options_lib.from_dict(client_options) + client_options = cast( + Optional[client_options_lib.ClientOptions], client_options + ) + custom_channel = None + self._emulator_host = os.getenv(BIGTABLE_EMULATOR) + if self._emulator_host is not None: + warnings.warn( + "Connecting to Bigtable emulator at {}".format(self._emulator_host), + RuntimeWarning, + stacklevel=2, + ) + custom_channel = insecure_channel(self._emulator_host) + if credentials is None: + credentials = google.auth.credentials.AnonymousCredentials() + if project is None: + project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT + ClientWithProject.__init__( + self, + credentials=credentials, + project=project, + client_options=client_options, + ) + self._gapic_client = CrossSync._Sync_Impl.GapicClient( + credentials=credentials, + client_options=client_options, + client_info=client_info, + transport=lambda *args, **kwargs: TransportType( + *args, **kwargs, channel=custom_channel + ), + ) + self._is_closed = CrossSync._Sync_Impl.Event() + self.transport = cast(TransportType, self._gapic_client.transport) + self._active_instances: Set[_WarmedInstanceKey] = set() + self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} + self._channel_init_time = time.monotonic() + self._channel_refresh_task: CrossSync._Sync_Impl.Task[None] | None = None + self._executor = ( + concurrent.futures.ThreadPoolExecutor() + if not CrossSync._Sync_Impl.is_async + else None + ) + if self._emulator_host is None: + try: + self._start_background_channel_refresh() + except RuntimeError: + warnings.warn( + f"{self.__class__.__name__} should be started in an asyncio event loop.
Channel refresh will not be started", + RuntimeWarning, + stacklevel=2, + ) + + @staticmethod + def _client_version() -> str: + """Helper function to return the client version string for this client""" + version_str = f"{google.cloud.bigtable.__version__}-data" + return version_str + + def _start_background_channel_refresh(self) -> None: + """Starts a background task to ping and warm grpc channel + + Raises: + None""" + if ( + not self._channel_refresh_task + and (not self._emulator_host) + and (not self._is_closed.is_set()) + ): + CrossSync._Sync_Impl.verify_async_event_loop() + self._channel_refresh_task = CrossSync._Sync_Impl.create_task( + self._manage_channel, + sync_executor=self._executor, + task_name=f"{self.__class__.__name__} channel refresh", + ) + + def close(self, timeout: float | None = 2.0): + """Cancel all background tasks""" + self._is_closed.set() + if self._channel_refresh_task is not None: + self._channel_refresh_task.cancel() + CrossSync._Sync_Impl.wait([self._channel_refresh_task], timeout=timeout) + self.transport.close() + if self._executor: + self._executor.shutdown(wait=False) + self._channel_refresh_task = None + + def _ping_and_warm_instances( + self, + instance_key: _WarmedInstanceKey | None = None, + channel: Channel | None = None, + ) -> list[BaseException | None]: + """Prepares the backend for requests on a channel + + Pings each Bigtable instance registered in `_active_instances` on the client + + Args: + instance_key: if provided, only warm the instance associated with the key + channel: grpc channel to warm. If none, warms `self.transport.grpc_channel` + Returns: + list[BaseException | None]: sequence of results or exceptions from the ping requests + """ + channel = channel or self.transport.grpc_channel + instance_list = ( + [instance_key] if instance_key is not None else self._active_instances + ) + ping_rpc = channel.unary_unary( + "/google.bigtable.v2.Bigtable/PingAndWarm", + request_serializer=PingAndWarmRequest.serialize, + ) + partial_list = [ + partial( + ping_rpc, + request={"name": instance_name, "app_profile_id": app_profile_id}, + metadata=[ + ( + "x-goog-request-params", + f"name={instance_name}&app_profile_id={app_profile_id}", + ) + ], + wait_for_ready=True, + ) + for (instance_name, table_name, app_profile_id) in instance_list + ] + result_list = CrossSync._Sync_Impl.gather_partials( + partial_list, return_exceptions=True, sync_executor=self._executor + ) + return [r or None for r in result_list] + + def _manage_channel( + self, + refresh_interval_min: float = 60 * 35, + refresh_interval_max: float = 60 * 45, + grace_period: float = 60 * 10, + ) -> None: + """Background task that periodically refreshes and warms a grpc channel + + The backend will automatically close channels after 60 minutes, so + `refresh_interval` + `grace_period` should be < 60 minutes + + Runs continuously until the client is closed + + Args: + refresh_interval_min: minimum interval before initiating refresh + process in seconds. Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + refresh_interval_max: maximum interval before initiating refresh + process in seconds. 
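The emulator handling above keys off the BIGTABLE_EMULATOR environment variable and skips the background channel refresh when it is set. A small sketch, assuming a local emulator at a placeholder address:

    import os
    from google.cloud.environment_vars import BIGTABLE_EMULATOR
    from google.cloud.bigtable.data import BigtableDataClient

    # BIGTABLE_EMULATOR holds the env-var name the client checks; the address is
    # a placeholder for a locally running emulator.
    os.environ[BIGTABLE_EMULATOR] = "localhost:8086"

    # With the variable set, the client uses an insecure channel and does not
    # start the channel-refresh task.
    client = BigtableDataClient()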
Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + grace_period: time to allow previous channel to serve existing + requests before closing, in seconds""" + first_refresh = self._channel_init_time + random.uniform( + refresh_interval_min, refresh_interval_max + ) + next_sleep = max(first_refresh - time.monotonic(), 0) + if next_sleep > 0: + self._ping_and_warm_instances(channel=self.transport.grpc_channel) + while not self._is_closed.is_set(): + CrossSync._Sync_Impl.event_wait( + self._is_closed, next_sleep, async_break_early=False + ) + if self._is_closed.is_set(): + break + start_timestamp = time.monotonic() + old_channel = self.transport.grpc_channel + new_channel = self.transport.create_channel() + self._ping_and_warm_instances(channel=new_channel) + self.transport._grpc_channel = new_channel + if grace_period: + self._is_closed.wait(grace_period) + old_channel.close() + next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) + next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) + + def _register_instance( + self, instance_id: str, owner: Table | ExecuteQueryIterator + ) -> None: + """Registers an instance with the client, and warms the channel for the instance + The client will periodically refresh grpc channel used to make + requests, and new channels will be warmed for each registered instance + Channels will not be refreshed unless at least one instance is registered + + Args: + instance_id: id of the instance to register. + owner: table that owns the instance. Owners will be tracked in + _instance_owners, and instances will only be unregistered when all + owners call _remove_instance_registration""" + instance_name = self._gapic_client.instance_path(self.project, instance_id) + instance_key = _WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + self._instance_owners.setdefault(instance_key, set()).add(id(owner)) + if instance_key not in self._active_instances: + self._active_instances.add(instance_key) + if self._channel_refresh_task: + self._ping_and_warm_instances(instance_key) + else: + self._start_background_channel_refresh() + + def _remove_instance_registration( + self, instance_id: str, owner: Table | "ExecuteQueryIterator" + ) -> bool: + """Removes an instance from the client's registered instances, to prevent + warming new channels for the instance + + If instance_id is not registered, or is still in use by other tables, returns False + + Args: + instance_id: id of the instance to remove + owner: table that owns the instance. Owners will be tracked in + _instance_owners, and instances will only be unregistered when all + owners call _remove_instance_registration + Returns: + bool: True if instance was removed, else False""" + instance_name = self._gapic_client.instance_path(self.project, instance_id) + instance_key = _WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + owner_list = self._instance_owners.get(instance_key, set()) + try: + owner_list.remove(id(owner)) + if len(owner_list) == 0: + self._active_instances.remove(instance_key) + return True + except KeyError: + return False + + def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: + """Returns a table instance for making data API requests. All arguments are passed + directly to the Table constructor. + + + + Args: + instance_id: The Bigtable instance ID to associate with this client. 
+ instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. table_id is combined with the + instance_id and the client's project to fully specify the table + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + Returns: + Table: a table instance for making data API requests + Raises: + None""" + return Table(self, instance_id, table_id, *args, **kwargs) + + def execute_query( + self, + query: str, + instance_id: str, + *, + parameters: dict[str, ExecuteQueryValueType] | None = None, + parameter_types: dict[str, SqlType.Type] | None = None, + app_profile_id: str | None = None, + operation_timeout: float = 600, + attempt_timeout: float | None = 20, + retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + Aborted, + ), + ) -> "ExecuteQueryIterator": + """Executes an SQL query on an instance. + Returns an iterator to asynchronously stream back columns from selected rows. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Args: + query: Query to be run on Bigtable instance. The query can use ``@param`` + placeholders to use parameter interpolation on the server. Values for all + parameters should be provided in ``parameters``. Types of parameters are + inferred but should be provided in ``parameter_types`` if the inference is + not possible (i.e. when value can be None, an empty list or an empty dict). + instance_id: The Bigtable instance ID to perform the query on. + instance_id is combined with the client's project to fully + specify the instance. + parameters: Dictionary with values for all parameters used in the ``query``. + parameter_types: Dictionary with types of parameters used in the ``query``. + Required to contain entries only for parameters whose type cannot be + detected automatically (i.e. the value can be None, an empty list or + an empty dict). + app_profile_id: The app profile to associate with requests. 
+ https://cloud.google.com/bigtable/docs/app-profiles + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to 600 seconds. + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the 20 seconds. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + Returns: + ExecuteQueryIterator: an asynchronous iterator that yields rows returned by the query + Raises: + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error + """ + warnings.warn( + "ExecuteQuery is in preview and may change in the future.", + category=RuntimeWarning, + ) + retryable_excs = [_get_error_type(e) for e in retryable_errors] + pb_params = _format_execute_query_params(parameters, parameter_types) + instance_name = self._gapic_client.instance_path(self.project, instance_id) + request_body = { + "instance_name": instance_name, + "app_profile_id": app_profile_id, + "query": query, + "params": pb_params, + "proto_format": {}, + } + return CrossSync._Sync_Impl.ExecuteQueryIterator( + self, + instance_id, + app_profile_id, + request_body, + attempt_timeout, + operation_timeout, + retryable_excs=retryable_excs, + ) + + def __enter__(self): + self._start_background_channel_refresh() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + self._gapic_client.__exit__(exc_type, exc_val, exc_tb) + + +@CrossSync._Sync_Impl.add_mapping_decorator("Table") +class Table: + """ + Main Data API surface + + Table object maintains table_id, and app_profile_id context, and passes them with + each call + """ + + def __init__( + self, + client: BigtableDataClient, + instance_id: str, + table_id: str, + app_profile_id: str | None = None, + *, + default_read_rows_operation_timeout: float = 600, + default_read_rows_attempt_timeout: float | None = 20, + default_mutate_rows_operation_timeout: float = 600, + default_mutate_rows_attempt_timeout: float | None = 60, + default_operation_timeout: float = 60, + default_attempt_timeout: float | None = 20, + default_read_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + Aborted, + ), + default_mutate_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), + default_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), + ): + """Initialize a Table instance + + + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. table_id is combined with the + instance_id and the client's project to fully specify the table + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. 
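A hedged sketch of the execute_query surface described above (ExecuteQuery is flagged as preview in this patch); the SQL text, parameter, and identifiers are placeholders:

    from google.cloud.bigtable.data import BigtableDataClient

    with BigtableDataClient(project="my-project") as client:
        # Placeholder query: looks up one row of table "my_table" by key.
        rows = client.execute_query(
            "SELECT _key, cf['col'] AS col FROM my_table WHERE _key = @key",
            instance_id="my-instance",
            parameters={"key": b"row-1"},
        )
        for row in rows:
            print(row["_key"], row["col"])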
If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + Raises: + None""" + _validate_timeouts( + default_operation_timeout, default_attempt_timeout, allow_none=True + ) + _validate_timeouts( + default_read_rows_operation_timeout, + default_read_rows_attempt_timeout, + allow_none=True, + ) + _validate_timeouts( + default_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout, + allow_none=True, + ) + self.client = client + self.instance_id = instance_id + self.instance_name = self.client._gapic_client.instance_path( + self.client.project, instance_id + ) + self.table_id = table_id + self.table_name = self.client._gapic_client.table_path( + self.client.project, instance_id, table_id + ) + self.app_profile_id = app_profile_id + self.default_operation_timeout = default_operation_timeout + self.default_attempt_timeout = default_attempt_timeout + self.default_read_rows_operation_timeout = default_read_rows_operation_timeout + self.default_read_rows_attempt_timeout = default_read_rows_attempt_timeout + self.default_mutate_rows_operation_timeout = ( + default_mutate_rows_operation_timeout + ) + self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout + self.default_read_rows_retryable_errors = ( + default_read_rows_retryable_errors or () + ) + self.default_mutate_rows_retryable_errors = ( + default_mutate_rows_retryable_errors or () + ) + self.default_retryable_errors = default_retryable_errors or () + try: + self._register_instance_future = CrossSync._Sync_Impl.create_task( + self.client._register_instance, + self.instance_id, + self, + sync_executor=self.client._executor, + ) + except RuntimeError as e: + raise RuntimeError( + f"{self.__class__.__name__} must be created within an async event loop context." + ) from e + + def read_rows_stream( + self, + query: ReadRowsQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> Iterable[Row]: + """Read a set of rows from the table, based on the specified query. + Returns an iterator to asynchronously stream back row data. 
+ + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Args: + query: contains details about which rows to return + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors + Returns: + Iterable[Row]: an asynchronous iterator that yields rows returned by the query + Raises: + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error + """ + (operation_timeout, attempt_timeout) = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + retryable_excs = _get_retryable_errors(retryable_errors, self) + row_merger = CrossSync._Sync_Impl._ReadRowsOperation( + query, + self, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_exceptions=retryable_excs, + ) + return row_merger.start_operation() + + def read_rows( + self, + query: ReadRowsQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> list[Row]: + """Read a set of rows from the table, based on the specified query. + Returns results as a list of Row objects when the request is complete. + For streamed results, use read_rows_stream. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Args: + query: contains details about which rows to return + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + If None, defaults to the Table's default_read_rows_attempt_timeout, + or the operation_timeout if that is also None. + retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors.
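A brief sketch of the streaming read shown above; `table` is assumed to come from BigtableDataClient.get_table, and the key range is a placeholder:

    from google.cloud.bigtable.data import ReadRowsQuery, RowRange

    def scan_users(table):
        # Placeholder half-open key range ["user#000", "user#999").
        query = ReadRowsQuery(row_ranges=[RowRange("user#000", "user#999")])
        for row in table.read_rows_stream(query):
            # Rows stream back lazily; transient errors are retried within the
            # table's default read_rows operation timeout.
            print(row.row_key, len(row.cells))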
+ Returns: + list[Row]: a list of Rows returned by the query + Raises: + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error + """ + row_generator = self.read_rows_stream( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) + return [row for row in row_generator] + + def read_row( + self, + row_key: str | bytes, + *, + row_filter: RowFilter | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> Row | None: + """Read a single row from the table, based on the specified key. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Args: + query: contains details about which rows to return + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. + Returns: + Row | None: a Row object if the row exists, otherwise None + Raises: + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error + """ + if row_key is None: + raise ValueError("row_key must be string or bytes") + query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) + results = self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) + if len(results) == 0: + return None + return results[0] + + def read_rows_sharded( + self, + sharded_query: ShardedQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> list[Row]: + """Runs a sharded query in parallel, then return the results in a single list. + Results will be returned in the order of the input queries. + + This function is intended to be run on the results on a query.shard() call. + For example:: + + table_shard_keys = await table.sample_row_keys() + query = ReadRowsQuery(...) + shard_queries = query.shard(table_shard_keys) + results = await table.read_rows_sharded(shard_queries) + + Args: + sharded_query: a sharded query to execute + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. 
+ Defaults to the Table's default_read_rows_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. + Returns: + list[Row]: a list of Rows returned by the query + Raises: + ShardedReadRowsExceptionGroup: if any of the queries failed + ValueError: if the query_list is empty""" + if not sharded_query: + raise ValueError("empty sharded_query") + (operation_timeout, attempt_timeout) = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + rpc_timeout_generator = _attempt_timeout_generator( + operation_timeout, operation_timeout + ) + concurrency_sem = CrossSync._Sync_Impl.Semaphore(_CONCURRENCY_LIMIT) + + def read_rows_with_semaphore(query): + with concurrency_sem: + shard_timeout = next(rpc_timeout_generator) + if shard_timeout <= 0: + raise DeadlineExceeded( + "Operation timeout exceeded before starting query" + ) + return self.read_rows( + query, + operation_timeout=shard_timeout, + attempt_timeout=min(attempt_timeout, shard_timeout), + retryable_errors=retryable_errors, + ) + + routine_list = [ + partial(read_rows_with_semaphore, query) for query in sharded_query + ] + batch_result = CrossSync._Sync_Impl.gather_partials( + routine_list, return_exceptions=True, sync_executor=self.client._executor + ) + error_dict = {} + shard_idx = 0 + results_list = [] + for result in batch_result: + if isinstance(result, Exception): + error_dict[shard_idx] = result + elif isinstance(result, BaseException): + raise result + else: + results_list.extend(result) + shard_idx += 1 + if error_dict: + raise ShardedReadRowsExceptionGroup( + [ + FailedQueryShardError(idx, sharded_query[idx], e) + for (idx, e) in error_dict.items() + ], + results_list, + len(sharded_query), + ) + return results_list + + def row_exists( + self, + row_key: str | bytes, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> bool: + """Return a boolean indicating whether the specified row exists in the table. + uses the filters: chain(limit cells per row = 1, strip value) + + Args: + row_key: the key of the row to check + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. 
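The docstring above keeps the async example text; a sync-flavoured sketch of the same sharded pattern, with `table` assumed to come from get_table:

    from google.cloud.bigtable.data import ReadRowsQuery

    def read_all_sharded(table):
        # Shard a full-table query along the table's sampled split points.
        shard_keys = table.sample_row_keys()
        shard_queries = ReadRowsQuery().shard(shard_keys)
        # Shards run with bounded concurrency; results are returned in the
        # order of the input queries.
        return table.read_rows_sharded(shard_queries)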
+ Returns:
+ bool: a bool indicating whether the row exists
+ Raises:
+ google.api_core.exceptions.DeadlineExceeded: raised after operation timeout
+ will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
+ from any retries that failed
+ google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error
+ """
+ if row_key is None:
+ raise ValueError("row_key must be string or bytes")
+ strip_filter = StripValueTransformerFilter(flag=True)
+ limit_filter = CellsRowLimitFilter(1)
+ chain_filter = RowFilterChain(filters=[limit_filter, strip_filter])
+ query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter)
+ results = self.read_rows(
+ query,
+ operation_timeout=operation_timeout,
+ attempt_timeout=attempt_timeout,
+ retryable_errors=retryable_errors,
+ )
+ return len(results) > 0
+
+ def sample_row_keys(
+ self,
+ *,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+ attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+ retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+ ) -> RowKeySamples:
+ """Return a set of RowKeySamples that delimit contiguous sections of the table of
+ approximately equal size.
+
+ RowKeySamples output can be used with ReadRowsQuery.shard() to create a sharded query that
+ can be parallelized across multiple backend nodes. read_rows and read_rows_stream
+ requests will call sample_row_keys internally for this purpose when sharding is enabled.
+
+ RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of
+ row_keys, along with offset positions in the table.
+
+ Args:
+ operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will be retried within the budget.
+ Defaults to the Table's default_operation_timeout
+ attempt_timeout: the time budget for an individual network request, in seconds.
+ If it takes longer than this time to complete, the request will be cancelled with
+ a DeadlineExceeded exception, and a retry will be attempted.
+ Defaults to the Table's default_attempt_timeout.
+ If None, defaults to operation_timeout.
+ retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_retryable_errors.
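+
+ For example (a minimal sketch; the number of samples and their offsets depend on the table,
+ and the final sample uses b"" to mark the end of the table)::
+
+     samples = table.sample_row_keys()
+     for row_key, offset_bytes in samples:
+         print(row_key, offset_bytes)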
+ Returns: + RowKeySamples: a set of RowKeySamples the delimit contiguous sections of the table + Raises: + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error + """ + (operation_timeout, attempt_timeout) = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + retryable_excs = _get_retryable_errors(retryable_errors, self) + predicate = retries.if_exception_type(*retryable_excs) + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + + def execute_rpc(): + results = self.client._gapic_client.sample_row_keys( + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=next(attempt_timeout_gen), + retry=None, + ) + return [(s.row_key, s.offset_bytes) for s in results] + + return CrossSync._Sync_Impl.retry_target( + execute_rpc, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) + + def mutations_batcher( + self, + *, + flush_interval: float | None = 5, + flush_limit_mutation_count: int | None = 1000, + flush_limit_bytes: int = 20 * _MB_SIZE, + flow_control_max_mutation_count: int = 100000, + flow_control_max_bytes: int = 100 * _MB_SIZE, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ) -> "MutationsBatcher": + """Returns a new mutations batcher instance. + + Can be used to iteratively add mutations that are flushed as a group, + to avoid excess network calls + + Args: + flush_interval: Automatically flush every flush_interval seconds. If None, + a table default will be used + flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count + mutations are added across all entries. If None, this limit is ignored. + flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. + flow_control_max_mutation_count: Maximum number of inflight mutations. + flow_control_max_bytes: Maximum number of inflight bytes. + batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + Defaults to the Table's default_mutate_rows_operation_timeout + batch_attempt_timeout: timeout for each individual request, in seconds. + Defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to batch_operation_timeout. + batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. 
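+
+ For example (a minimal sketch; ``entries`` stands in for RowMutationEntry objects
+ built by the caller)::
+
+     with table.mutations_batcher(flush_interval=2) as batcher:
+         for entry in entries:
+             batcher.append(entry)
+     # any mutations still staged are flushed when the context manager exits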
+ Returns: + MutationsBatcher: a MutationsBatcher context manager that can batch requests + """ + return CrossSync._Sync_Impl.MutationsBatcher( + self, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_mutation_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=batch_operation_timeout, + batch_attempt_timeout=batch_attempt_timeout, + batch_retryable_errors=batch_retryable_errors, + ) + + def mutate_row( + self, + row_key: str | bytes, + mutations: list[Mutation] | Mutation, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ): + """Mutates a row atomically. + + Cells already present in the row are left unchanged unless explicitly changed + by ``mutation``. + + Idempotent operations (i.e, all mutations have an explicit timestamp) will be + retried on server failure. Non-idempotent operations will not. + + Args: + row_key: the row to apply mutations to + mutations: the set of mutations to apply to the row + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_attempt_timeout. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Only idempotent mutations will be retried. Defaults to the Table's + default_retryable_errors. + Raises: + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing all + GoogleAPIError exceptions from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised on non-idempotent operations that cannot be + safely retried. 
+ ValueError: if invalid arguments are provided"""
+ (operation_timeout, attempt_timeout) = _get_timeouts(
+ operation_timeout, attempt_timeout, self
+ )
+ if not mutations:
+ raise ValueError("No mutations provided")
+ mutations_list = mutations if isinstance(mutations, list) else [mutations]
+ if all((mutation.is_idempotent() for mutation in mutations_list)):
+ predicate = retries.if_exception_type(
+ *_get_retryable_errors(retryable_errors, self)
+ )
+ else:
+ predicate = retries.if_exception_type()
+ sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60)
+ target = partial(
+ self.client._gapic_client.mutate_row,
+ row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
+ mutations=[mutation._to_pb() for mutation in mutations_list],
+ table_name=self.table_name,
+ app_profile_id=self.app_profile_id,
+ timeout=attempt_timeout,
+ retry=None,
+ )
+ return CrossSync._Sync_Impl.retry_target(
+ target,
+ predicate,
+ sleep_generator,
+ operation_timeout,
+ exception_factory=_retry_exception_factory,
+ )
+
+ def bulk_mutate_rows(
+ self,
+ mutation_entries: list[RowMutationEntry],
+ *,
+ operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ retryable_errors: Sequence[type[Exception]]
+ | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+ ):
+ """Applies mutations for multiple rows in a single batched request.
+
+ Each individual RowMutationEntry is applied atomically, but separate entries
+ may be applied in arbitrary order (even for entries targeting the same row).
+ In total, the row_mutations can contain at most 100000 individual mutations
+ across all entries.
+
+ Idempotent entries (i.e., entries with mutations with explicit timestamps)
+ will be retried on failure. Non-idempotent entries will not, and will be reported in a
+ raised exception group.
+
+ Args:
+ mutation_entries: the batches of mutations to apply
+ Each entry will be applied atomically, but entries will be applied
+ in arbitrary order
+ operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will be retried within the budget.
+ Defaults to the Table's default_mutate_rows_operation_timeout
+ attempt_timeout: the time budget for an individual network request, in seconds.
+ If it takes longer than this time to complete, the request will be cancelled with
+ a DeadlineExceeded exception, and a retry will be attempted.
+ Defaults to the Table's default_mutate_rows_attempt_timeout.
+ If None, defaults to operation_timeout.
+ retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_mutate_rows_retryable_errors + Raises: + MutationsExceptionGroup: if one or more mutations fails + Contains details about any failed entries in .exceptions + ValueError: if invalid arguments are provided""" + (operation_timeout, attempt_timeout) = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + retryable_excs = _get_retryable_errors(retryable_errors, self) + operation = CrossSync._Sync_Impl._MutateRowsOperation( + self.client._gapic_client, + self, + mutation_entries, + operation_timeout, + attempt_timeout, + retryable_exceptions=retryable_excs, + ) + operation.start() + + def check_and_mutate_row( + self, + row_key: str | bytes, + predicate: RowFilter | None, + *, + true_case_mutations: Mutation | list[Mutation] | None = None, + false_case_mutations: Mutation | list[Mutation] | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> bool: + """Mutates a row atomically based on the output of a predicate filter + + Non-idempotent operation: will not be retried + + Args: + row_key: the key of the row to mutate + predicate: the filter to be applied to the contents of the specified row. + Depending on whether or not any results are yielded, + either true_case_mutations or false_case_mutations will be executed. + If None, checks that the row contains any values at all. + true_case_mutations: + Changes to be atomically applied to the specified row if + predicate yields at least one cell when + applied to row_key. Entries are applied in order, + meaning that earlier mutations can be masked by later + ones. Must contain at least one entry if + false_case_mutations is empty, and at most 100000. + false_case_mutations: + Changes to be atomically applied to the specified row if + predicate_filter does not yield any cells when + applied to row_key. Entries are applied in order, + meaning that earlier mutations can be masked by later + ones. Must contain at least one entry if + `true_case_mutations` is empty, and at most 100000. + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will not be retried. 
Defaults to the Table's default_operation_timeout + Returns: + bool indicating whether the predicate was true or false + Raises: + google.api_core.exceptions.GoogleAPIError: exceptions from grpc call""" + (operation_timeout, _) = _get_timeouts(operation_timeout, None, self) + if true_case_mutations is not None and ( + not isinstance(true_case_mutations, list) + ): + true_case_mutations = [true_case_mutations] + true_case_list = [m._to_pb() for m in true_case_mutations or []] + if false_case_mutations is not None and ( + not isinstance(false_case_mutations, list) + ): + false_case_mutations = [false_case_mutations] + false_case_list = [m._to_pb() for m in false_case_mutations or []] + result = self.client._gapic_client.check_and_mutate_row( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=operation_timeout, + retry=None, + ) + return result.predicate_matched + + def read_modify_write_row( + self, + row_key: str | bytes, + rules: ReadModifyWriteRule | list[ReadModifyWriteRule], + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> Row: + """Reads and modifies a row atomically according to input ReadModifyWriteRules, + and returns the contents of all modified cells + + The new value for the timestamp is the greater of the existing timestamp or + the current server time. + + Non-idempotent operation: will not be retried + + Args: + row_key: the key of the row to apply read/modify/write rules to + rules: A rule or set of rules to apply to the row. + Rules are applied in order, meaning that earlier rules will affect the + results of later ones. + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will not be retried. + Defaults to the Table's default_operation_timeout. 
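+
+ For example (an illustrative sketch; the family and qualifier names are placeholders)::
+
+     from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule
+
+     row = table.read_modify_write_row(b"row-key", IncrementRule("family", b"qualifier", 1))
+     # row contains the post-increment contents of the modified cell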
+ Returns: + Row: a Row containing cell data that was modified as part of the operation + Raises: + google.api_core.exceptions.GoogleAPIError: exceptions from grpc call + ValueError: if invalid arguments are provided""" + (operation_timeout, _) = _get_timeouts(operation_timeout, None, self) + if operation_timeout <= 0: + raise ValueError("operation_timeout must be greater than 0") + if rules is not None and (not isinstance(rules, list)): + rules = [rules] + if not rules: + raise ValueError("rules must contain at least one item") + result = self.client._gapic_client.read_modify_write_row( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=operation_timeout, + retry=None, + ) + return Row._from_pb(result.row) + + def close(self): + """Called to close the Table instance and release any resources held by it.""" + if self._register_instance_future: + self._register_instance_future.cancel() + self.client._remove_instance_registration(self.instance_id, self) + + def __enter__(self): + """Implement async context manager protocol + + Ensure registration task has time to run, so that + grpc channels will be warmed for the specified instance""" + if self._register_instance_future: + self._register_instance_future + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Implement async context manager protocol + + Unregister this instance with the client, so that + grpc channels will no longer be warmed""" + self.close() diff --git a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py new file mode 100644 index 000000000..2e4237b74 --- /dev/null +++ b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py @@ -0,0 +1,449 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This file is automatically generated by CrossSync. Do not edit manually. 
+ +from __future__ import annotations +from typing import Sequence, TYPE_CHECKING, cast +import atexit +import warnings +from collections import deque +import concurrent.futures +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +from google.cloud.bigtable.data._helpers import _get_retryable_errors +from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data._cross_sync import CrossSync + +if TYPE_CHECKING: + from google.cloud.bigtable.data.mutations import RowMutationEntry + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType +_MB_SIZE = 1024 * 1024 + + +@CrossSync._Sync_Impl.add_mapping_decorator("_FlowControl") +class _FlowControl: + """ + Manages flow control for batched mutations. Mutations are registered against + the FlowControl object before being sent, which will block if size or count + limits have reached capacity. As mutations completed, they are removed from + the FlowControl object, which will notify any blocked requests that there + is additional capacity. + + Flow limits are not hard limits. If a single mutation exceeds the configured + limits, it will be allowed as a single batch when the capacity is available. + + Args: + max_mutation_count: maximum number of mutations to send in a single rpc. + This corresponds to individual mutations in a single RowMutationEntry. + max_mutation_bytes: maximum number of bytes to send in a single rpc. + Raises: + ValueError: if max_mutation_count or max_mutation_bytes is less than 0 + """ + + def __init__(self, max_mutation_count: int, max_mutation_bytes: int): + self._max_mutation_count = max_mutation_count + self._max_mutation_bytes = max_mutation_bytes + if self._max_mutation_count < 1: + raise ValueError("max_mutation_count must be greater than 0") + if self._max_mutation_bytes < 1: + raise ValueError("max_mutation_bytes must be greater than 0") + self._capacity_condition = CrossSync._Sync_Impl.Condition() + self._in_flight_mutation_count = 0 + self._in_flight_mutation_bytes = 0 + + def _has_capacity(self, additional_count: int, additional_size: int) -> bool: + """Checks if there is capacity to send a new entry with the given size and count + + FlowControl limits are not hard limits. If a single mutation exceeds + the configured flow limits, it will be sent in a single batch when + previous batches have completed. + + Args: + additional_count: number of mutations in the pending entry + additional_size: size of the pending entry + Returns: + bool: True if there is capacity to send the pending entry, False otherwise + """ + acceptable_size = max(self._max_mutation_bytes, additional_size) + acceptable_count = max(self._max_mutation_count, additional_count) + new_size = self._in_flight_mutation_bytes + additional_size + new_count = self._in_flight_mutation_count + additional_count + return new_size <= acceptable_size and new_count <= acceptable_count + + def remove_from_flow( + self, mutations: RowMutationEntry | list[RowMutationEntry] + ) -> None: + """Removes mutations from flow control. This method should be called once + for each mutation that was sent to add_to_flow, after the corresponding + operation is complete. 
+
+ Args:
+ mutations: mutation or list of mutations to remove from flow control"""
+ if not isinstance(mutations, list):
+ mutations = [mutations]
+ total_count = sum((len(entry.mutations) for entry in mutations))
+ total_size = sum((entry.size() for entry in mutations))
+ self._in_flight_mutation_count -= total_count
+ self._in_flight_mutation_bytes -= total_size
+ with self._capacity_condition:
+ self._capacity_condition.notify_all()
+
+ def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]):
+ """Generator function that registers mutations with flow control. As mutations
+ are accepted into the flow control, they are yielded back to the caller,
+ to be sent in a batch. If the flow control is at capacity, the generator
+ will block until there is capacity available.
+
+ Args:
+ mutations: list of mutations to break up into batches
+ Yields:
+ list[RowMutationEntry]:
+ list of mutations that have reserved space in the flow control.
+ Each batch contains at least one mutation."""
+ if not isinstance(mutations, list):
+ mutations = [mutations]
+ start_idx = 0
+ end_idx = 0
+ while end_idx < len(mutations):
+ start_idx = end_idx
+ batch_mutation_count = 0
+ with self._capacity_condition:
+ while end_idx < len(mutations):
+ next_entry = mutations[end_idx]
+ next_size = next_entry.size()
+ next_count = len(next_entry.mutations)
+ if (
+ self._has_capacity(next_count, next_size)
+ and batch_mutation_count + next_count
+ <= _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
+ ):
+ end_idx += 1
+ batch_mutation_count += next_count
+ self._in_flight_mutation_bytes += next_size
+ self._in_flight_mutation_count += next_count
+ elif start_idx != end_idx:
+ break
+ else:
+ self._capacity_condition.wait_for(
+ lambda: self._has_capacity(next_count, next_size)
+ )
+ yield mutations[start_idx:end_idx]
+
+
+class MutationsBatcher:
+ """
+ Allows users to send batches using context manager API:
+
+ Runs mutate_row, mutate_rows, and check_and_mutate_row internally, combining
+ to use as few network requests as required
+
+ Will automatically flush the batcher:
+ - every flush_interval seconds
+ - after queue size reaches flush_limit_mutation_count
+ - after queue reaches flush_limit_bytes
+ - when batcher is closed or destroyed
+
+ Args:
+ table: Table to perform rpc calls
+ flush_interval: Automatically flush every flush_interval seconds.
+ If None, no time-based flushing is performed.
+ flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count
+ mutations are added across all entries. If None, this limit is ignored.
+ flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added.
+ flow_control_max_mutation_count: Maximum number of inflight mutations.
+ flow_control_max_bytes: Maximum number of inflight bytes.
+ batch_operation_timeout: timeout for each mutate_rows operation, in seconds.
+ If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_operation_timeout.
+ batch_attempt_timeout: timeout for each individual request, in seconds.
+ If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout.
+ If None, defaults to batch_operation_timeout.
+ batch_retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_mutate_rows_retryable_errors.
+ """ + + def __init__( + self, + table: TableType, + *, + flush_interval: float | None = 5, + flush_limit_mutation_count: int | None = 1000, + flush_limit_bytes: int = 20 * _MB_SIZE, + flow_control_max_mutation_count: int = 100000, + flow_control_max_bytes: int = 100 * _MB_SIZE, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ): + (self._operation_timeout, self._attempt_timeout) = _get_timeouts( + batch_operation_timeout, batch_attempt_timeout, table + ) + self._retryable_errors: list[type[Exception]] = _get_retryable_errors( + batch_retryable_errors, table + ) + self._closed = CrossSync._Sync_Impl.Event() + self._table = table + self._staged_entries: list[RowMutationEntry] = [] + (self._staged_count, self._staged_bytes) = (0, 0) + self._flow_control = CrossSync._Sync_Impl._FlowControl( + flow_control_max_mutation_count, flow_control_max_bytes + ) + self._flush_limit_bytes = flush_limit_bytes + self._flush_limit_count = ( + flush_limit_mutation_count + if flush_limit_mutation_count is not None + else float("inf") + ) + self._sync_rpc_executor = ( + concurrent.futures.ThreadPoolExecutor(max_workers=8) + if not CrossSync._Sync_Impl.is_async + else None + ) + self._sync_flush_executor = ( + concurrent.futures.ThreadPoolExecutor(max_workers=4) + if not CrossSync._Sync_Impl.is_async + else None + ) + self._flush_timer = CrossSync._Sync_Impl.create_task( + self._timer_routine, flush_interval, sync_executor=self._sync_flush_executor + ) + self._flush_jobs: set[CrossSync._Sync_Impl.Future[None]] = set() + self._entries_processed_since_last_raise: int = 0 + self._exceptions_since_last_raise: int = 0 + self._exception_list_limit: int = 10 + self._oldest_exceptions: list[Exception] = [] + self._newest_exceptions: deque[Exception] = deque( + maxlen=self._exception_list_limit + ) + atexit.register(self._on_exit) + + def _timer_routine(self, interval: float | None) -> None: + """Set up a background task to flush the batcher every interval seconds + + If interval is None, an empty future is returned + + Args: + flush_interval: Automatically flush every flush_interval seconds. + If None, no time-based flushing is performed.""" + if not interval or interval <= 0: + return None + while not self._closed.is_set(): + CrossSync._Sync_Impl.event_wait( + self._closed, timeout=interval, async_break_early=False + ) + if not self._closed.is_set() and self._staged_entries: + self._schedule_flush() + + def append(self, mutation_entry: RowMutationEntry): + """Add a new set of mutations to the internal queue + + Args: + mutation_entry: new entry to add to flush queue + Raises: + RuntimeError: if batcher is closed + ValueError: if an invalid mutation type is added""" + if self._closed.is_set(): + raise RuntimeError("Cannot append to closed MutationsBatcher") + if isinstance(cast(Mutation, mutation_entry), Mutation): + raise ValueError( + f"invalid mutation type: {type(mutation_entry).__name__}. 
Only RowMutationEntry objects are supported by batcher" + ) + self._staged_entries.append(mutation_entry) + self._staged_count += len(mutation_entry.mutations) + self._staged_bytes += mutation_entry.size() + if ( + self._staged_count >= self._flush_limit_count + or self._staged_bytes >= self._flush_limit_bytes + ): + self._schedule_flush() + CrossSync._Sync_Impl.yield_to_event_loop() + + def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: + """Update the flush task to include the latest staged entries + + Returns: + Future[None] | None: + future representing the background task, if started""" + if self._staged_entries: + (entries, self._staged_entries) = (self._staged_entries, []) + (self._staged_count, self._staged_bytes) = (0, 0) + new_task = CrossSync._Sync_Impl.create_task( + self._flush_internal, entries, sync_executor=self._sync_flush_executor + ) + if not new_task.done(): + self._flush_jobs.add(new_task) + new_task.add_done_callback(self._flush_jobs.remove) + return new_task + return None + + def _flush_internal(self, new_entries: list[RowMutationEntry]): + """Flushes a set of mutations to the server, and updates internal state + + Args: + new_entries list of RowMutationEntry objects to flush""" + in_process_requests: list[ + CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] + ] = [] + for batch in self._flow_control.add_to_flow(new_entries): + batch_task = CrossSync._Sync_Impl.create_task( + self._execute_mutate_rows, batch, sync_executor=self._sync_rpc_executor + ) + in_process_requests.append(batch_task) + found_exceptions = self._wait_for_batch_results(*in_process_requests) + self._entries_processed_since_last_raise += len(new_entries) + self._add_exceptions(found_exceptions) + + def _execute_mutate_rows( + self, batch: list[RowMutationEntry] + ) -> list[FailedMutationEntryError]: + """Helper to execute mutation operation on a batch + + Args: + batch: list of RowMutationEntry objects to send to server + timeout: timeout in seconds. Used as operation_timeout and attempt_timeout. + If not given, will use table defaults + Returns: + list[FailedMutationEntryError]: + list of FailedMutationEntryError objects for mutations that failed. + FailedMutationEntryError objects will not contain index information""" + try: + operation = CrossSync._Sync_Impl._MutateRowsOperation( + self._table.client._gapic_client, + self._table, + batch, + operation_timeout=self._operation_timeout, + attempt_timeout=self._attempt_timeout, + retryable_exceptions=self._retryable_errors, + ) + operation.start() + except MutationsExceptionGroup as e: + for subexc in e.exceptions: + subexc.index = None + return list(e.exceptions) + finally: + self._flow_control.remove_from_flow(batch) + return [] + + def _add_exceptions(self, excs: list[Exception]): + """Add new list of exceptions to internal store. To avoid unbounded memory, + the batcher will store the first and last _exception_list_limit exceptions, + and discard any in between. 
+ + Args: + excs: list of exceptions to add to the internal store""" + self._exceptions_since_last_raise += len(excs) + if excs and len(self._oldest_exceptions) < self._exception_list_limit: + addition_count = self._exception_list_limit - len(self._oldest_exceptions) + self._oldest_exceptions.extend(excs[:addition_count]) + excs = excs[addition_count:] + if excs: + self._newest_exceptions.extend(excs[-self._exception_list_limit :]) + + def _raise_exceptions(self): + """Raise any unreported exceptions from background flush operations + + Raises: + MutationsExceptionGroup: exception group with all unreported exceptions""" + if self._oldest_exceptions or self._newest_exceptions: + (oldest, self._oldest_exceptions) = (self._oldest_exceptions, []) + newest = list(self._newest_exceptions) + self._newest_exceptions.clear() + (entry_count, self._entries_processed_since_last_raise) = ( + self._entries_processed_since_last_raise, + 0, + ) + (exc_count, self._exceptions_since_last_raise) = ( + self._exceptions_since_last_raise, + 0, + ) + raise MutationsExceptionGroup.from_truncated_lists( + first_list=oldest, + last_list=newest, + total_excs=exc_count, + entry_count=entry_count, + ) + + def __enter__(self): + """Allow use of context manager API""" + return self + + def __exit__(self, exc_type, exc, tb): + """Allow use of context manager API. + + Flushes the batcher and cleans up resources.""" + self.close() + + @property + def closed(self) -> bool: + """Returns: + - True if the batcher is closed, False otherwise""" + return self._closed.is_set() + + def close(self): + """Flush queue and clean up resources""" + self._closed.set() + self._flush_timer.cancel() + self._schedule_flush() + if self._sync_flush_executor: + with self._sync_flush_executor: + self._sync_flush_executor.shutdown(wait=True) + if self._sync_rpc_executor: + with self._sync_rpc_executor: + self._sync_rpc_executor.shutdown(wait=True) + CrossSync._Sync_Impl.wait([*self._flush_jobs, self._flush_timer]) + atexit.unregister(self._on_exit) + self._raise_exceptions() + + def _on_exit(self): + """Called when program is exited. Raises warning if unflushed mutations remain""" + if not self._closed.is_set() and self._staged_entries: + warnings.warn( + f"MutationsBatcher for table {self._table.table_name} was not closed. {len(self._staged_entries)} Unflushed mutations will not be sent to the server." + ) + + @staticmethod + def _wait_for_batch_results( + *tasks: CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] + | CrossSync._Sync_Impl.Future[None], + ) -> list[Exception]: + """Takes in a list of futures representing _execute_mutate_rows tasks, + waits for them to complete, and returns a list of errors encountered. + + Args: + *tasks: futures representing _execute_mutate_rows or _flush_internal tasks + Returns: + list[Exception]: + list of Exceptions encountered by any of the tasks. Errors are expected + to be FailedMutationEntryError, representing a failed mutation operation. + If a task fails with a different exception, it will be included in the + output list. Successful tasks will not be represented in the output list. 
+ """ + if not tasks: + return [] + exceptions: list[Exception] = [] + for task in tasks: + try: + exc_list = task.result() + if exc_list: + for exc in exc_list: + exc.index = None + exceptions.extend(exc_list) + except Exception as e: + exceptions.append(e) + return exceptions diff --git a/google/cloud/bigtable/data/execute_query/__init__.py b/google/cloud/bigtable/data/execute_query/__init__.py index 0ff258365..31fd5e3cc 100644 --- a/google/cloud/bigtable/data/execute_query/__init__.py +++ b/google/cloud/bigtable/data/execute_query/__init__.py @@ -15,6 +15,9 @@ from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( ExecuteQueryIteratorAsync, ) +from google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator import ( + ExecuteQueryIterator, +) from google.cloud.bigtable.data.execute_query.metadata import ( Metadata, ProtoMetadata, @@ -28,6 +31,7 @@ from google.cloud.bigtable.data._cross_sync import CrossSync CrossSync.add_mapping("ExecuteQueryIterator", ExecuteQueryIteratorAsync) +CrossSync._Sync_Impl.add_mapping("ExecuteQueryIterator", ExecuteQueryIterator) __all__ = [ "ExecuteQueryValueType", @@ -37,4 +41,5 @@ "Metadata", "ProtoMetadata", "ExecuteQueryIteratorAsync", + "ExecuteQueryIterator", ] diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index ba82bbcca..66f264610 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -46,6 +46,8 @@ if TYPE_CHECKING: if CrossSync.is_async: from google.cloud.bigtable.data import BigtableDataClientAsync as DataClientType + else: + from google.cloud.bigtable.data import BigtableDataClient as DataClientType __CROSS_SYNC_OUTPUT__ = ( "google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator" diff --git a/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py new file mode 100644 index 000000000..854148ff3 --- /dev/null +++ b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py @@ -0,0 +1,186 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This file is automatically generated by CrossSync. Do not edit manually. 
+ +from __future__ import annotations +from typing import Any, Dict, Optional, Sequence, Tuple, TYPE_CHECKING +from google.api_core import retry as retries +from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor +from google.cloud.bigtable.data._helpers import ( + _attempt_timeout_generator, + _retry_exception_factory, +) +from google.cloud.bigtable.data.exceptions import InvalidExecuteQueryResponse +from google.cloud.bigtable.data.execute_query.values import QueryResultRow +from google.cloud.bigtable.data.execute_query.metadata import Metadata, ProtoMetadata +from google.cloud.bigtable.data.execute_query._reader import ( + _QueryResultRowReader, + _Reader, +) +from google.cloud.bigtable_v2.types.bigtable import ( + ExecuteQueryRequest as ExecuteQueryRequestPB, +) +from google.cloud.bigtable.data._cross_sync import CrossSync + +if TYPE_CHECKING: + from google.cloud.bigtable.data import BigtableDataClient as DataClientType + + +class ExecuteQueryIterator: + def __init__( + self, + client: DataClientType, + instance_id: str, + app_profile_id: Optional[str], + request_body: Dict[str, Any], + attempt_timeout: float | None, + operation_timeout: float, + req_metadata: Sequence[Tuple[str, str]] = (), + retryable_excs: Sequence[type[Exception]] = (), + ) -> None: + """Collects responses from ExecuteQuery requests and parses them into QueryResultRows. + + It is **not thread-safe**. It should not be used by multiple threads. + + Args: + client: bigtable client + instance_id: id of the instance on which the query is executed + request_body: dict representing the body of the ExecuteQueryRequest + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget + req_metadata: metadata used while sending the gRPC request + retryable_excs: a list of errors that will be retried if encountered. + Raises: + None""" + self._table_name = None + self._app_profile_id = app_profile_id + self._client = client + self._instance_id = instance_id + self._byte_cursor = _ByteCursor[ProtoMetadata]() + self._reader: _Reader[QueryResultRow] = _QueryResultRowReader(self._byte_cursor) + self._result_generator = self._next_impl() + self._register_instance_task = None + self._is_closed = False + self._request_body = request_body + self._attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self._stream = CrossSync._Sync_Impl.retry_target_stream( + self._make_request_with_resume_token, + retries.if_exception_type(*retryable_excs), + retries.exponential_sleep_generator(0.01, 60, multiplier=2), + operation_timeout, + exception_factory=_retry_exception_factory, + ) + self._req_metadata = req_metadata + try: + self._register_instance_task = CrossSync._Sync_Impl.create_task( + self._client._register_instance, + instance_id, + self, + sync_executor=self._client._executor, + ) + except RuntimeError as e: + raise RuntimeError( + f"{self.__class__.__name__} must be created within an async event loop context." 
+ ) from e
+
+ @property
+ def is_closed(self) -> bool:
+ """Returns True if the iterator is closed, False otherwise."""
+ return self._is_closed
+
+ @property
+ def app_profile_id(self) -> Optional[str]:
+ """Returns the app_profile_id of the iterator."""
+ return self._app_profile_id
+
+ @property
+ def table_name(self) -> Optional[str]:
+ """Returns the table_name of the iterator."""
+ return self._table_name
+
+ def _make_request_with_resume_token(self):
+ """Performs the rpc call using the correct resume token."""
+ resume_token = self._byte_cursor.prepare_for_new_request()
+ request = ExecuteQueryRequestPB(
+ {**self._request_body, "resume_token": resume_token}
+ )
+ return self._client._gapic_client.execute_query(
+ request,
+ timeout=next(self._attempt_timeout_gen),
+ metadata=self._req_metadata,
+ retry=None,
+ )
+
+ def _fetch_metadata(self) -> None:
+ """If called before the first response was received, the first response
+ is retrieved as part of this call."""
+ if self._byte_cursor.metadata is None:
+ metadata_msg = self._stream.__next__()
+ self._byte_cursor.consume_metadata(metadata_msg)
+
+ def _next_impl(self) -> CrossSync._Sync_Impl.Iterator[QueryResultRow]:
+ """Generator wrapping the response stream which parses the stream results
+ and returns full `QueryResultRow`s."""
+ self._fetch_metadata()
+ for response in self._stream:
+ try:
+ bytes_to_parse = self._byte_cursor.consume(response)
+ if bytes_to_parse is None:
+ continue
+ results = self._reader.consume(bytes_to_parse)
+ if results is None:
+ continue
+ except ValueError as e:
+ raise InvalidExecuteQueryResponse(
+ "Invalid ExecuteQuery response received"
+ ) from e
+ for result in results:
+ yield result
+ self.close()
+
+ def __next__(self) -> QueryResultRow:
+ if self._is_closed:
+ raise CrossSync._Sync_Impl.StopIteration
+ return self._result_generator.__next__()
+
+ def __iter__(self):
+ return self
+
+ def metadata(self) -> Optional[Metadata]:
+ """Returns query metadata from the server or None if the iterator was
+ explicitly closed."""
+ if self._is_closed:
+ return None
+ if self._byte_cursor.metadata is None:
+ try:
+ self._fetch_metadata()
+ except CrossSync._Sync_Impl.StopIteration:
+ return None
+ return self._byte_cursor.metadata
+
+ def close(self) -> None:
+ """Cancel all background tasks.
Should be called all rows were processed.""" + if self._is_closed: + return + self._is_closed = True + if self._register_instance_task is not None: + self._register_instance_task.cancel() + self._client._remove_instance_registration(self._instance_id, self) diff --git a/noxfile.py b/noxfile.py index f6a2291fc..8576fed85 100644 --- a/noxfile.py +++ b/noxfile.py @@ -28,7 +28,7 @@ import nox FLAKE8_VERSION = "flake8==6.1.0" -BLACK_VERSION = "black[jupyter]==23.7.0" +BLACK_VERSION = "black[jupyter]==23.3.0" ISORT_VERSION = "isort==5.11.0" LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] @@ -49,6 +49,8 @@ "pytest", "pytest-cov", "pytest-asyncio", + BLACK_VERSION, + "autoflake", ] UNIT_TEST_EXTERNAL_DEPENDENCIES: List[str] = [] UNIT_TEST_LOCAL_DEPENDENCIES: List[str] = [] @@ -64,7 +66,7 @@ ] SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ "pytest-asyncio==0.21.2", - "black==23.7.0", + BLACK_VERSION, "pyyaml==6.0.2", ] SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = [] @@ -561,3 +563,13 @@ def prerelease_deps(session, protobuf_implementation): "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, }, ) + + +@nox.session(python="3.10") +def generate_sync(session): + """ + Re-generate sync files for the library from CrossSync-annotated async source + """ + session.install(BLACK_VERSION) + session.install("autoflake") + session.run("python", ".cross_sync/generate.py", ".") diff --git a/test_proxy/handlers/client_handler_data_async.py b/test_proxy/handlers/client_handler_data_async.py index 7f6cc413f..49539c1aa 100644 --- a/test_proxy/handlers/client_handler_data_async.py +++ b/test_proxy/handlers/client_handler_data_async.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index c0e9f39d2..b97859de1 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py new file mode 100644 index 000000000..2dde82bf1 --- /dev/null +++ b/tests/system/data/test_system_autogen.py @@ -0,0 +1,828 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This file is automatically generated by CrossSync. Do not edit manually. + +import pytest +import uuid +import os +from google.api_core import retry +from google.api_core.exceptions import ClientError +from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE +from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.cloud.bigtable.data._cross_sync import CrossSync +from . 
import TEST_FAMILY, TEST_FAMILY_2 + + +@CrossSync._Sync_Impl.add_mapping_decorator("TempRowBuilder") +class TempRowBuilder: + """ + Used to add rows to a table for testing purposes. + """ + + def __init__(self, table): + self.rows = [] + self.table = table + + def add_row( + self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" + ): + if isinstance(value, str): + value = value.encode("utf-8") + elif isinstance(value, int): + value = value.to_bytes(8, byteorder="big", signed=True) + request = { + "table_name": self.table.table_name, + "row_key": row_key, + "mutations": [ + { + "set_cell": { + "family_name": family, + "column_qualifier": qualifier, + "value": value, + } + } + ], + } + self.table.client._gapic_client.mutate_row(request) + self.rows.append(row_key) + + def delete_rows(self): + if self.rows: + request = { + "table_name": self.table.table_name, + "entries": [ + {"row_key": row, "mutations": [{"delete_from_row": {}}]} + for row in self.rows + ], + } + self.table.client._gapic_client.mutate_rows(request) + + +class TestSystem: + @pytest.fixture(scope="session") + def client(self): + project = os.getenv("GOOGLE_CLOUD_PROJECT") or None + with CrossSync._Sync_Impl.DataClient(project=project) as client: + yield client + + @pytest.fixture(scope="session") + def table(self, client, table_id, instance_id): + with client.get_table(instance_id, table_id) as table: + yield table + + @pytest.fixture(scope="session") + def column_family_config(self): + """specify column families to create when creating a new test table""" + from google.cloud.bigtable_admin_v2 import types + + return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} + + @pytest.fixture(scope="session") + def init_table_id(self): + """The table_id to use when creating a new test table""" + return f"test-table-{uuid.uuid4().hex}" + + @pytest.fixture(scope="session") + def cluster_config(self, project_id): + """Configuration for the clusters to use when creating a new instance""" + from google.cloud.bigtable_admin_v2 import types + + cluster = { + "test-cluster": types.Cluster( + location=f"projects/{project_id}/locations/us-central1-b", serve_nodes=1 + ) + } + return cluster + + @pytest.mark.usefixtures("table") + def _retrieve_cell_value(self, table, row_key): + """Helper to read an individual row""" + from google.cloud.bigtable.data import ReadRowsQuery + + row_list = table.read_rows(ReadRowsQuery(row_keys=row_key)) + assert len(row_list) == 1 + row = row_list[0] + cell = row.cells[0] + return cell.value + + def _create_row_and_mutation( + self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" + ): + """Helper to create a new row, and a sample set_cell mutation to change its value""" + from google.cloud.bigtable.data.mutations import SetCell + + row_key = uuid.uuid4().hex.encode() + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row( + row_key, family=family, qualifier=qualifier, value=start_value + ) + assert self._retrieve_cell_value(table, row_key) == start_value + mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) + return (row_key, mutation) + + @pytest.fixture(scope="function") + def temp_rows(self, table): + builder = CrossSync._Sync_Impl.TempRowBuilder(table) + yield builder + builder.delete_rows() + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=10 + ) + def 
test_ping_and_warm_gapic(self, client, table): + """Simple ping rpc test + This test ensures channels are able to authenticate with backend""" + request = {"name": table.instance_name} + client._gapic_client.ping_and_warm(request) + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_ping_and_warm(self, client, table): + """Test ping and warm from handwritten client""" + results = client._ping_and_warm_instances() + assert len(results) == 1 + assert results[0] is None + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutation_set_cell(self, table, temp_rows): + """Ensure cells can be set properly""" + row_key = b"bulk_mutate" + new_value = uuid.uuid4().hex.encode() + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + table.mutate_row(row_key, mutation) + assert self._retrieve_cell_value(table, row_key) == new_value + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" + ) + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_sample_row_keys(self, client, table, temp_rows, column_split_config): + """Sample keys should return a single sample in small test tables""" + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + results = table.sample_row_keys() + assert len(results) == len(column_split_config) + 1 + for idx in range(len(column_split_config)): + assert results[idx][0] == column_split_config[idx] + assert isinstance(results[idx][1], int) + assert results[-1][0] == b"" + assert isinstance(results[-1][1], int) + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_bulk_mutations_set_cell(self, client, table, temp_rows): + """Ensure cells can be set properly""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + table.bulk_mutate_rows([bulk_mutation]) + assert self._retrieve_cell_value(table, row_key) == new_value + + def test_bulk_mutations_raise_exception(self, client, table): + """If an invalid mutation is passed, an exception should be raised""" + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + row_key = uuid.uuid4().hex.encode() + mutation = SetCell( + family="nonexistent", qualifier=b"test-qualifier", new_value=b"" + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + with pytest.raises(MutationsExceptionGroup) as exc: + table.bulk_mutate_rows([bulk_mutation]) + assert len(exc.value.exceptions) == 1 + entry_error = exc.value.exceptions[0] + assert isinstance(entry_error, FailedMutationEntryError) + assert entry_error.index == 0 + assert entry_error.entry == bulk_mutation + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def 
test_mutations_batcher_context_manager(self, client, table, temp_rows): + """test batcher with context manager. Should flush on exit""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + (row_key2, mutation2) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + with table.mutations_batcher() as batcher: + batcher.append(bulk_mutation) + batcher.append(bulk_mutation2) + assert self._retrieve_cell_value(table, row_key) == new_value + assert len(batcher._staged_entries) == 0 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_timer_flush(self, client, table, temp_rows): + """batch should occur after flush_interval seconds""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + with table.mutations_batcher(flush_interval=flush_interval) as batcher: + batcher.append(bulk_mutation) + CrossSync._Sync_Impl.yield_to_event_loop() + assert len(batcher._staged_entries) == 1 + CrossSync._Sync_Impl.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + assert self._retrieve_cell_value(table, row_key) == new_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_count_flush(self, client, table, temp_rows): + """batch should flush after flush_limit_mutation_count mutations""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + (row_key2, mutation2) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 1 + for future in list(batcher._flush_jobs): + future + future.result() + assert len(batcher._staged_entries) == 0 + assert len(batcher._flush_jobs) == 0 + assert self._retrieve_cell_value(table, row_key) == new_value + assert self._retrieve_cell_value(table, row_key2) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): + """batch should flush after flush_limit_bytes bytes""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + 
(row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + (row_key2, mutation2) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 + with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 1 + assert len(batcher._staged_entries) == 0 + for future in list(batcher._flush_jobs): + future + future.result() + assert self._retrieve_cell_value(table, row_key) == new_value + assert self._retrieve_cell_value(table, row_key2) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_mutations_batcher_no_flush(self, client, table, temp_rows): + """test with no flush requirements met""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + start_value = b"unchanged" + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + (row_key2, mutation2) = self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 + with table.mutations_batcher( + flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 + ) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 0 + CrossSync._Sync_Impl.yield_to_event_loop() + assert len(batcher._staged_entries) == 2 + assert len(batcher._flush_jobs) == 0 + assert self._retrieve_cell_value(table, row_key) == start_value + assert self._retrieve_cell_value(table, row_key2) == start_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_large_batch(self, client, table, temp_rows): + """test batcher with large batch of mutations""" + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + + add_mutation = SetCell( + family=TEST_FAMILY, qualifier=b"test-qualifier", new_value=b"a" + ) + row_mutations = [] + for i in range(50000): + row_key = uuid.uuid4().hex.encode() + row_mutations.append(RowMutationEntry(row_key, [add_mutation])) + temp_rows.rows.append(row_key) + with table.mutations_batcher() as batcher: + for mutation in row_mutations: + batcher.append(mutation) + assert len(batcher._staged_entries) == 0 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,increment,expected", + [ + (0, 0, 0), + (0, 1, 1), + (0, -1, -1), + (1, 0, 1), + (0, -100, -100), + (0, 3000, 3000), + (10, 4, 14), + (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), + (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), + (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), + ], + ) + def test_read_modify_write_row_increment( + self, client, table, temp_rows, start, increment, expected + ): + """test read_modify_write_row""" + from 
google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + rule = IncrementRule(family, qualifier, increment) + result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert int(result[0]) == expected + assert self._retrieve_cell_value(table, row_key) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,append,expected", + [ + (b"", b"", b""), + ("", "", b""), + (b"abc", b"123", b"abc123"), + (b"abc", "123", b"abc123"), + ("", b"1", b"1"), + (b"abc", "", b"abc"), + (b"hello", b"world", b"helloworld"), + ], + ) + def test_read_modify_write_row_append( + self, client, table, temp_rows, start, append, expected + ): + """test read_modify_write_row""" + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + rule = AppendValueRule(family, qualifier, append) + result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert result[0].value == expected + assert self._retrieve_cell_value(table, row_key) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_read_modify_write_row_chained(self, client, table, temp_rows): + """test read_modify_write_row with multiple rules""" + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + start_amount = 1 + increment_amount = 10 + temp_rows.add_row( + row_key, value=start_amount, family=family, qualifier=qualifier + ) + rule = [ + IncrementRule(family, qualifier, increment_amount), + AppendValueRule(family, qualifier, "hello"), + AppendValueRule(family, qualifier, "world"), + AppendValueRule(family, qualifier, "!"), + ] + result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert result[0].family == family + assert result[0].qualifier == qualifier + assert ( + result[0].value + == (start_amount + increment_amount).to_bytes(8, "big", signed=True) + + b"helloworld!" 
+ ) + assert self._retrieve_cell_value(table, row_key) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start_val,predicate_range,expected_result", + [(1, (0, 2), True), (-1, (0, 2), False)], + ) + def test_check_and_mutate( + self, client, table, temp_rows, start_val, predicate_range, expected_result + ): + """test that check_and_mutate_row applies the right mutations and returns the right result""" + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row(row_key, value=start_val, family=family, qualifier=qualifier) + false_mutation_value = b"false-mutation-value" + false_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value + ) + true_mutation_value = b"true-mutation-value" + true_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value + ) + predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) + result = table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + assert result == expected_result + expected_value = ( + true_mutation_value if expected_result else false_mutation_value + ) + assert self._retrieve_cell_value(table, row_key) == expected_value + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_check_and_mutate_empty_request(self, client, table): + """check_and_mutate with no true or false mutations should raise an error""" + from google.api_core import exceptions + + with pytest.raises(exceptions.InvalidArgument) as e: + table.check_and_mutate_row( + b"row_key", None, true_case_mutations=None, false_case_mutations=None + ) + assert "No mutations provided" in str(e.value) + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_stream(self, table, temp_rows): + """Ensure that the read_rows_stream method works""" + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + generator = table.read_rows_stream({}) + first_row = generator.__next__() + second_row = generator.__next__() + assert first_row.row_key == b"row_key_1" + assert second_row.row_key == b"row_key_2" + with pytest.raises(CrossSync._Sync_Impl.StopIteration): + generator.__next__() + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows(self, table, temp_rows): + """Ensure that the read_rows method works""" + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + row_list = table.read_rows({}) + assert len(row_list) == 2 + assert row_list[0].row_key == b"row_key_1" + assert row_list[1].row_key == b"row_key_2" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_sharded_simple(self, table, temp_rows): + """Test read rows sharded with two queries""" + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + 
temp_rows.add_row(b"d") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) + row_list = table.read_rows_sharded([query1, query2]) + assert len(row_list) == 4 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"b" + assert row_list[3].row_key == b"d" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_sharded_from_sample(self, table, temp_rows): + """Test end-to-end sharding""" + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.read_rows_query import RowRange + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + table_shard_keys = table.sample_row_keys() + query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) + shard_queries = query.shard(table_shard_keys) + row_list = table.read_rows_sharded(shard_queries) + assert len(row_list) == 3 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"d" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_sharded_filters_limits(self, table, temp_rows): + """Test read rows sharded with filters and limits""" + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + label_filter1 = ApplyLabelFilter("first") + label_filter2 = ApplyLabelFilter("second") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) + row_list = table.read_rows_sharded([query1, query2]) + assert len(row_list) == 3 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"b" + assert row_list[2].row_key == b"d" + assert row_list[0][0].labels == ["first"] + assert row_list[1][0].labels == ["second"] + assert row_list[2][0].labels == ["second"] + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_range_query(self, table, temp_rows): + """Ensure that the read_rows method works""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data import RowRange + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) + row_list = table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_single_key_query(self, table, temp_rows): + """Ensure that the read_rows method works with specified query""" + from google.cloud.bigtable.data import ReadRowsQuery + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + query = ReadRowsQuery(row_keys=[b"a", b"c"]) + row_list = table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"a" + assert 
row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_with_filter(self, table, temp_rows): + """ensure filters are applied""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + expected_label = "test-label" + row_filter = ApplyLabelFilter(expected_label) + query = ReadRowsQuery(row_filter=row_filter) + row_list = table.read_rows(query) + assert len(row_list) == 4 + for row in row_list: + assert row[0].labels == [expected_label] + + @pytest.mark.usefixtures("table") + def test_read_rows_stream_close(self, table, temp_rows): + """Ensure that the read_rows_stream can be closed""" + from google.cloud.bigtable.data import ReadRowsQuery + + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + query = ReadRowsQuery() + generator = table.read_rows_stream(query) + first_row = generator.__next__() + assert first_row.row_key == b"row_key_1" + generator.close() + with pytest.raises(CrossSync._Sync_Impl.StopIteration): + generator.__next__() + + @pytest.mark.usefixtures("table") + def test_read_row(self, table, temp_rows): + """Test read_row (single row helper)""" + from google.cloud.bigtable.data import Row + + temp_rows.add_row(b"row_key_1", value=b"value") + row = table.read_row(b"row_key_1") + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + def test_read_row_missing(self, table): + """Test read_row when row does not exist""" + from google.api_core import exceptions + + row_key = "row_key_not_exist" + result = table.read_row(row_key) + assert result is None + with pytest.raises(exceptions.InvalidArgument) as e: + table.read_row("") + assert "Row keys must be non-empty" in str(e) + + @pytest.mark.usefixtures("table") + def test_read_row_w_filter(self, table, temp_rows): + """Test read_row (single row helper)""" + from google.cloud.bigtable.data import Row + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"row_key_1", value=b"value") + expected_label = "test-label" + label_filter = ApplyLabelFilter(expected_label) + row = table.read_row(b"row_key_1", row_filter=label_filter) + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + assert row.cells[0].labels == [expected_label] + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + def test_row_exists(self, table, temp_rows): + from google.api_core import exceptions + + "Test row_exists with rows that exist and don't exist" + assert table.row_exists(b"row_key_1") is False + temp_rows.add_row(b"row_key_1") + assert table.row_exists(b"row_key_1") is True + assert table.row_exists("row_key_1") is True + assert table.row_exists(b"row_key_2") is False + assert table.row_exists("row_key_2") is False + assert table.row_exists("3") is False + temp_rows.add_row(b"3") + assert table.row_exists(b"3") is True + with pytest.raises(exceptions.InvalidArgument) as e: + table.row_exists("") + assert "Row keys must be non-empty" 
in str(e) + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @pytest.mark.parametrize( + "cell_value,filter_input,expect_match", + [ + (b"abc", b"abc", True), + (b"abc", "abc", True), + (b".", ".", True), + (".*", ".*", True), + (".*", b".*", True), + ("a", ".*", False), + (b".*", b".*", True), + ("\\a", "\\a", True), + (b"\xe2\x98\x83", "☃", True), + ("☃", "☃", True), + ("\\C☃", "\\C☃", True), + (1, 1, True), + (2, 1, False), + (68, 68, True), + ("D", 68, False), + (68, "D", False), + (-1, -1, True), + (2852126720, 2852126720, True), + (-1431655766, -1431655766, True), + (-1431655766, -1, False), + ], + ) + def test_literal_value_filter( + self, table, temp_rows, cell_value, filter_input, expect_match + ): + """Literal value filter does complex escaping on re2 strings. + Make sure inputs are properly interpreted by the server""" + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable.data import ReadRowsQuery + + f = LiteralValueFilter(filter_input) + temp_rows.add_row(b"row_key_1", value=cell_value) + query = ReadRowsQuery(row_filter=f) + row_list = table.read_rows(query) + assert len(row_list) == bool( + expect_match + ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 621f4d9a2..13f668fd3 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 6a4583a7b..944681a84 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -1,3 +1,4 @@ +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index c24fa3d98..8d829a363 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -1279,7 +1279,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ # we expect an exception from attempting to call the mock pass assert rpc_mock.call_count == 1 - kwargs = rpc_mock.call_args_list[0].kwargs + kwargs = rpc_mock.call_args_list[0][1] metadata = kwargs["metadata"] # expect single metadata entry assert len(metadata) == 1 @@ -1906,7 +1906,7 @@ async def mock_call(*args, **kwargs): assert read_rows.call_count == 10 assert len(result) == 10 # if run in sequence, we would expect this to take 1 second - assert call_time < 0.2 + assert call_time < 0.5 @CrossSync.pytest async def test_read_rows_sharded_concurrency_limit(self): @@ -2005,21 +2005,25 @@ async def test_read_rows_sharded_negative_batch_timeout(self): They should raise DeadlineExceeded errors """ from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT from google.api_core.exceptions import DeadlineExceeded async def mock_call(*args, **kwargs): - await CrossSync.sleep(0.05) + await CrossSync.sleep(0.06) return [mock.Mock()] async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call - queries = [ReadRowsQuery() for _ in range(15)] + num_calls = 15 + queries = [ReadRowsQuery() for _ in range(num_calls)] with pytest.raises(ShardedReadRowsExceptionGroup) as exc: - await table.read_rows_sharded(queries, operation_timeout=0.01) + await table.read_rows_sharded(queries, operation_timeout=0.05) assert isinstance(exc.value, ShardedReadRowsExceptionGroup) - assert len(exc.value.exceptions) == 5 + # _CONCURRENCY_LIMIT calls will run, and won't be interrupted + # calls after the limit will be cancelled due to timeout + assert len(exc.value.exceptions) >= num_calls - _CONCURRENCY_LIMIT assert all( isinstance(e.__cause__, DeadlineExceeded) for e in exc.value.exceptions diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index cd442d392..2df8dde6d 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 45d139182..ab9502223 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/data/_sync_autogen/__init__.py b/tests/unit/data/_sync_autogen/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/_sync_autogen/test__mutate_rows.py b/tests/unit/data/_sync_autogen/test__mutate_rows.py new file mode 100644 index 000000000..2173c88fb --- /dev/null +++ b/tests/unit/data/_sync_autogen/test__mutate_rows.py @@ -0,0 +1,307 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This file is automatically generated by CrossSync. Do not edit manually. + +import pytest +from google.cloud.bigtable_v2.types import MutateRowsResponse +from google.rpc import status_pb2 +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import Forbidden +from google.cloud.bigtable.data._cross_sync import CrossSync + +try: + from unittest import mock +except ImportError: + import mock + + +class TestMutateRowsOperation: + def _target_class(self): + return CrossSync._Sync_Impl._MutateRowsOperation + + def _make_one(self, *args, **kwargs): + if not args: + kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) + kwargs["table"] = kwargs.pop("table", CrossSync._Sync_Impl.Mock()) + kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) + kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) + kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) + kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) + return self._target_class()(*args, **kwargs) + + def _make_mutation(self, count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def _mock_stream(self, mutation_list, error_dict): + for idx, entry in enumerate(mutation_list): + code = error_dict.get(idx, 0) + yield MutateRowsResponse( + entries=[ + MutateRowsResponse.Entry( + index=idx, status=status_pb2.Status(code=code) + ) + ] + ) + + def _make_mock_gapic(self, mutation_list, error_dict=None): + mock_fn = CrossSync._Sync_Impl.Mock() + if error_dict is None: + error_dict = {} + mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( + mutation_list, error_dict + ) + return mock_fn + + def test_ctor(self): + """test that constructor sets all the attributes correctly""" + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import Aborted + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + attempt_timeout = 0.01 + retryable_exceptions = () + instance = self._make_one( + client, + table, + entries, + operation_timeout, + attempt_timeout, + retryable_exceptions, + ) + assert client.mutate_rows.call_count == 0 + instance._gapic_fn() + assert client.mutate_rows.call_count == 1 + inner_kwargs = client.mutate_rows.call_args[1] + assert len(inner_kwargs) == 3 + assert inner_kwargs["table_name"] == table.table_name + assert inner_kwargs["app_profile_id"] == table.app_profile_id + assert inner_kwargs["retry"] is None + entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] + assert instance.mutations == entries_w_pb + assert next(instance.timeout_generator) == attempt_timeout + assert instance.is_retryable is not None + assert instance.is_retryable(DeadlineExceeded("")) is False + assert instance.is_retryable(Aborted("")) is False + assert 
instance.is_retryable(_MutateRowsIncomplete("")) is True + assert instance.is_retryable(RuntimeError("")) is False + assert instance.remaining_indices == list(range(len(entries))) + assert instance.errors == {} + + def test_ctor_too_many_entries(self): + """should raise an error if an operation is created with more than 100,000 entries""" + from google.cloud.bigtable.data._async._mutate_rows import ( + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, + ) + + assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100000 + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] * (_MUTATE_ROWS_REQUEST_MUTATION_LIMIT + 1) + operation_timeout = 0.05 + attempt_timeout = 0.01 + with pytest.raises(ValueError) as e: + self._make_one(client, table, entries, operation_timeout, attempt_timeout) + assert "mutate_rows requests can contain at most 100000 mutations" in str( + e.value + ) + assert "Found 100001" in str(e.value) + + def test_mutate_rows_operation(self): + """Test successful case of mutate_rows_operation""" + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + cls = self._target_class() + with mock.patch( + f"{cls.__module__}.{cls.__name__}._run_attempt", CrossSync._Sync_Impl.Mock() + ) as attempt_mock: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + assert attempt_mock.call_count == 1 + + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) + def test_mutate_rows_attempt_exception(self, exc_type): + """exceptions raised from attempt should be raised in MutationsExceptionGroup""" + client = CrossSync._Sync_Impl.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + expected_exception = exc_type("test") + client.mutate_rows.side_effect = expected_exception + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance._run_attempt() + except Exception as e: + found_exc = e + assert client.mutate_rows.call_count == 1 + assert type(found_exc) is exc_type + assert found_exc == expected_exception + assert len(instance.errors) == 2 + assert len(instance.remaining_indices) == 0 + + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) + def test_mutate_rows_exception(self, exc_type): + """exceptions raised from retryable should be raised in MutationsExceptionGroup""" + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + expected_cause = exc_type("abort") + with mock.patch.object( + self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() + ) as attempt_mock: + attempt_mock.side_effect = expected_cause + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count == 1 + assert len(found_exc.exceptions) == 2 + assert isinstance(found_exc.exceptions[0], FailedMutationEntryError) + assert isinstance(found_exc.exceptions[1], FailedMutationEntryError) + assert found_exc.exceptions[0].__cause__ == expected_cause + assert found_exc.exceptions[1].__cause__ == 
expected_cause + + @pytest.mark.parametrize("exc_type", [DeadlineExceeded, RuntimeError]) + def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): + """If attempts fail with a retryable exception but eventually succeed, no exception should be raised""" + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] + operation_timeout = 1 + expected_cause = exc_type("retry") + num_retries = 2 + with mock.patch.object( + self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() + ) as attempt_mock: + attempt_mock.side_effect = [expected_cause] * num_retries + [None] + instance = self._make_one( + client, + table, + entries, + operation_timeout, + operation_timeout, + retryable_exceptions=(exc_type,), + ) + instance.start() + assert attempt_mock.call_count == num_retries + 1 + + def test_mutate_rows_incomplete_ignored(self): + """MutateRowsIncomplete exceptions should not be added to error list""" + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.api_core.exceptions import DeadlineExceeded + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] + operation_timeout = 0.05 + with mock.patch.object( + self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() + ) as attempt_mock: + attempt_mock.side_effect = _MutateRowsIncomplete("ignored") + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count > 0 + assert len(found_exc.exceptions) == 1 + assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) + + def test_run_attempt_single_entry_success(self): + """Test mutating a single entry""" + mutation = self._make_mutation() + expected_timeout = 1.3 + mock_gapic_fn = self._make_mock_gapic({0: mutation}) + instance = self._make_one( + mutation_entries=[mutation], attempt_timeout=expected_timeout + ) + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + instance._run_attempt() + assert len(instance.remaining_indices) == 0 + assert mock_gapic_fn.call_count == 1 + (_, kwargs) = mock_gapic_fn.call_args + assert kwargs["timeout"] == expected_timeout + assert kwargs["entries"] == [mutation._to_pb()] + + def test_run_attempt_empty_request(self): + """Calling with no mutations should result in no API calls""" + mock_gapic_fn = self._make_mock_gapic([]) + instance = self._make_one(mutation_entries=[]) + instance._run_attempt() + assert mock_gapic_fn.call_count == 0 + + def test_run_attempt_partial_success_retryable(self): + """Some entries succeed, but one fails. 
Should report the proper index, and raise incomplete exception""" + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one(mutation_entries=mutations) + instance.is_retryable = lambda x: True + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + with pytest.raises(_MutateRowsIncomplete): + instance._run_attempt() + assert instance.remaining_indices == [1] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors + + def test_run_attempt_partial_success_non_retryable(self): + """Some entries succeed, but one fails. Exception marked as non-retryable. Do not raise incomplete error""" + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one(mutation_entries=mutations) + instance.is_retryable = lambda x: False + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + instance._run_attempt() + assert instance.remaining_indices == [] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors diff --git a/tests/unit/data/_sync_autogen/test__read_rows.py b/tests/unit/data/_sync_autogen/test__read_rows.py new file mode 100644 index 000000000..973b07bcb --- /dev/null +++ b/tests/unit/data/_sync_autogen/test__read_rows.py @@ -0,0 +1,354 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This file is automatically generated by CrossSync. Do not edit manually. 
+ +import pytest +from google.cloud.bigtable.data._cross_sync import CrossSync + +try: + from unittest import mock +except ImportError: + import mock + + +class TestReadRowsOperation: + """ + Tests helper functions in the ReadRowsOperation class + in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt + is tested in test_read_rows_acceptance test_client_read_rows, and conformance tests + """ + + @staticmethod + def _get_target_class(): + return CrossSync._Sync_Impl._ReadRowsOperation + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor(self): + from google.cloud.bigtable.data import ReadRowsQuery + + row_limit = 91 + query = ReadRowsQuery(limit=row_limit) + client = mock.Mock() + client.read_rows = mock.Mock() + client.read_rows.return_value = None + table = mock.Mock() + table._client = client + table.table_name = "test_table" + table.app_profile_id = "test_profile" + expected_operation_timeout = 42 + expected_request_timeout = 44 + time_gen_mock = mock.Mock() + subpath = "_async" if CrossSync._Sync_Impl.is_async else "_sync_autogen" + with mock.patch( + f"google.cloud.bigtable.data.{subpath}._read_rows._attempt_timeout_generator", + time_gen_mock, + ): + instance = self._make_one( + query, + table, + operation_timeout=expected_operation_timeout, + attempt_timeout=expected_request_timeout, + ) + assert time_gen_mock.call_count == 1 + time_gen_mock.assert_called_once_with( + expected_request_timeout, expected_operation_timeout + ) + assert instance._last_yielded_row_key is None + assert instance._remaining_count == row_limit + assert instance.operation_timeout == expected_operation_timeout + assert client.read_rows.call_count == 0 + assert instance.request.table_name == table.table_name + assert instance.request.app_profile_id == table.app_profile_id + assert instance.request.rows_limit == row_limit + + @pytest.mark.parametrize( + "in_keys,last_key,expected", + [ + (["b", "c", "d"], "a", ["b", "c", "d"]), + (["a", "b", "c"], "b", ["c"]), + (["a", "b", "c"], "c", []), + (["a", "b", "c"], "d", []), + (["d", "c", "b", "a"], "b", ["d", "c"]), + ], + ) + @pytest.mark.parametrize("with_range", [True, False]) + def test_revise_request_rowset_keys_with_range( + self, in_keys, last_key, expected, with_range + ): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + from google.cloud.bigtable.data.exceptions import _RowSetComplete + + in_keys = [key.encode("utf-8") for key in in_keys] + expected = [key.encode("utf-8") for key in expected] + last_key = last_key.encode("utf-8") + if with_range: + sample_range = [RowRangePB(start_key_open=last_key)] + else: + sample_range = [] + row_set = RowSetPB(row_keys=in_keys, row_ranges=sample_range) + if not with_range and expected == []: + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, last_key) + else: + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == expected + assert revised.row_ranges == sample_range + + @pytest.mark.parametrize( + "in_ranges,last_key,expected", + [ + ( + [{"start_key_open": "b", "end_key_closed": "d"}], + "a", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "a", + [{"start_key_closed": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_open": "a", "end_key_closed": "d"}], + "b", + 
[{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "a", "end_key_open": "d"}], + "b", + [{"start_key_open": "b", "end_key_open": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "b", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_open": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "e", []), + ([{"start_key_closed": "b"}], "z", [{"start_key_open": "z"}]), + ([{"start_key_closed": "b"}], "a", [{"start_key_closed": "b"}]), + ( + [{"end_key_closed": "z"}], + "a", + [{"start_key_open": "a", "end_key_closed": "z"}], + ), + ( + [{"end_key_open": "z"}], + "a", + [{"start_key_open": "a", "end_key_open": "z"}], + ), + ], + ) + @pytest.mark.parametrize("with_key", [True, False]) + def test_revise_request_rowset_ranges( + self, in_ranges, last_key, expected, with_key + ): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + from google.cloud.bigtable.data.exceptions import _RowSetComplete + + next_key = (last_key + "a").encode("utf-8") + last_key = last_key.encode("utf-8") + in_ranges = [ + RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) + for r in in_ranges + ] + expected = [ + RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) + for r in expected + ] + if with_key: + row_keys = [next_key] + else: + row_keys = [] + row_set = RowSetPB(row_ranges=in_ranges, row_keys=row_keys) + if not with_key and expected == []: + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, last_key) + else: + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == row_keys + assert revised.row_ranges == expected + + @pytest.mark.parametrize("last_key", ["a", "b", "c"]) + def test_revise_request_full_table(self, last_key): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + last_key = last_key.encode("utf-8") + row_set = RowSetPB() + for selected_set in [row_set, None]: + revised = self._get_target_class()._revise_request_rowset( + selected_set, last_key + ) + assert revised.row_keys == [] + assert len(revised.row_ranges) == 1 + assert revised.row_ranges[0] == RowRangePB(start_key_open=last_key) + + def test_revise_to_empty_rowset(self): + """revising to an empty rowset should raise error""" + from google.cloud.bigtable.data.exceptions import _RowSetComplete + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + row_keys = [b"a", b"b", b"c"] + row_range = RowRangePB(end_key_open=b"c") + row_set = RowSetPB(row_keys=row_keys, row_ranges=[row_range]) + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, b"d") + + @pytest.mark.parametrize( + "start_limit,emit_num,expected_limit", + [ + (10, 0, 10), + (10, 1, 9), + (10, 10, 0), + (None, 10, None), + (None, 0, None), + (4, 2, 2), + ], + ) + def test_revise_limit(self, start_limit, emit_num, expected_limit): + """revise_limit should revise the request's limit field + - if limit is 0 (unlimited), it should never be revised + - if start_limit-emit_num == 0, the request should end early + - if the number emitted exceeds the new limit, an exception should + should be raised (tested in 
test_revise_limit_over_limit)""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse + + def awaitable_stream(): + def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) + + return mock_stream() + + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + for val in instance.chunk_stream(awaitable_stream()): + pass + assert instance._remaining_count == expected_limit + + @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) + def test_revise_limit_over_limit(self, start_limit, emit_num): + """Should raise runtime error if we get in state where emit_num > start_num + (unless start_num == 0, which represents unlimited)""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse + from google.cloud.bigtable.data.exceptions import InvalidChunk + + def awaitable_stream(): + def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) + + return mock_stream() + + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + with pytest.raises(InvalidChunk) as e: + for val in instance.chunk_stream(awaitable_stream()): + pass + assert "emit count exceeds row limit" in str(e.value) + + def test_close(self): + """should be able to close a stream safely with close. 
+ Closed generators should raise StopAsyncIteration on next yield""" + + def mock_stream(): + while True: + yield 1 + + with mock.patch.object( + self._get_target_class(), "_read_rows_attempt" + ) as mock_attempt: + instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) + wrapped_gen = mock_stream() + mock_attempt.return_value = wrapped_gen + gen = instance.start_operation() + gen.__next__() + gen.close() + with pytest.raises(CrossSync._Sync_Impl.StopIteration): + gen.__next__() + gen.close() + with pytest.raises(CrossSync._Sync_Impl.StopIteration): + wrapped_gen.__next__() + + def test_retryable_ignore_repeated_rows(self): + """Duplicate rows should cause an invalid chunk error""" + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import ReadRowsResponse + + row_key = b"duplicate" + + def mock_awaitable_stream(): + def mock_stream(): + while True: + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + + return mock_stream() + + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + stream = self._get_target_class().chunk_stream( + instance, mock_awaitable_stream() + ) + stream.__next__() + with pytest.raises(InvalidChunk) as exc: + stream.__next__() + assert "row keys should be strictly increasing" in str(exc.value) diff --git a/tests/unit/data/_sync_autogen/test_client.py b/tests/unit/data/_sync_autogen/test_client.py new file mode 100644 index 000000000..51c88c63e --- /dev/null +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -0,0 +1,2889 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is automatically generated by CrossSync. Do not edit manually. 
+ +from __future__ import annotations +import grpc +import asyncio +import re +import pytest +import mock +from google.cloud.bigtable.data import mutations +from google.auth.credentials import AnonymousCredentials +from google.cloud.bigtable_v2.types import ReadRowsResponse +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.api_core import exceptions as core_exceptions +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule +from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule +from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse +from google.cloud.bigtable.data._cross_sync import CrossSync +from google.api_core import grpc_helpers + +CrossSync._Sync_Impl.add_mapping("grpc_helpers", grpc_helpers) + + +@CrossSync._Sync_Impl.add_mapping_decorator("TestBigtableDataClient") +class TestBigtableDataClient: + @staticmethod + def _get_target_class(): + return CrossSync._Sync_Impl.DataClient + + @classmethod + def _make_client(cls, *args, use_emulator=True, **kwargs): + import os + + env_mask = {} + if use_emulator: + env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" + import warnings + + warnings.filterwarnings("ignore", category=RuntimeWarning) + else: + kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) + kwargs["project"] = kwargs.get("project", "project-id") + with mock.patch.dict(os.environ, env_mask): + return cls._get_target_class()(*args, **kwargs) + + def test_ctor(self): + expected_project = "project-id" + expected_credentials = AnonymousCredentials() + client = self._make_client( + project="project-id", credentials=expected_credentials, use_emulator=False + ) + CrossSync._Sync_Impl.yield_to_event_loop() + assert client.project == expected_project + assert not client._active_instances + assert client._channel_refresh_task is not None + assert client.transport._credentials == expected_credentials + client.close() + + def test_ctor_super_inits(self): + from google.cloud.client import ClientWithProject + from google.api_core import client_options as client_options_lib + + project = "project-id" + credentials = AnonymousCredentials() + client_options = {"api_endpoint": "foo.bar:1234"} + options_parsed = client_options_lib.from_dict(client_options) + with mock.patch.object( + CrossSync._Sync_Impl.GapicClient, "__init__" + ) as bigtable_client_init: + bigtable_client_init.return_value = None + with mock.patch.object( + ClientWithProject, "__init__" + ) as client_project_init: + client_project_init.return_value = None + try: + self._make_client( + project=project, + credentials=credentials, + client_options=options_parsed, + use_emulator=False, + ) + except AttributeError: + pass + assert bigtable_client_init.call_count == 1 + kwargs = bigtable_client_init.call_args[1] + assert kwargs["credentials"] == credentials + assert kwargs["client_options"] == options_parsed + assert client_project_init.call_count == 1 + kwargs = client_project_init.call_args[1] + assert kwargs["project"] == project + assert kwargs["credentials"] == credentials + assert kwargs["client_options"] == options_parsed + + def test_ctor_dict_options(self): + from google.api_core.client_options import ClientOptions + + client_options = {"api_endpoint": "foo.bar:1234"} + with mock.patch.object( + 
CrossSync._Sync_Impl.GapicClient, "__init__" + ) as bigtable_client_init: + try: + self._make_client(client_options=client_options) + except TypeError: + pass + bigtable_client_init.assert_called_once() + kwargs = bigtable_client_init.call_args[1] + called_options = kwargs["client_options"] + assert called_options.api_endpoint == "foo.bar:1234" + assert isinstance(called_options, ClientOptions) + with mock.patch.object( + self._get_target_class(), "_start_background_channel_refresh" + ) as start_background_refresh: + client = self._make_client( + client_options=client_options, use_emulator=False + ) + start_background_refresh.assert_called_once() + client.close() + + def test_veneer_grpc_headers(self): + client_component = "data-async" if CrossSync._Sync_Impl.is_async else "data" + VENEER_HEADER_REGEX = re.compile( + "gapic\\/[0-9]+\\.[\\w.-]+ gax\\/[0-9]+\\.[\\w.-]+ gccl\\/[0-9]+\\.[\\w.-]+-" + + client_component + + " gl-python\\/[0-9]+\\.[\\w.-]+ grpc\\/[0-9]+\\.[\\w.-]+" + ) + patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") + with patch as gapic_mock: + client = self._make_client(project="project-id") + wrapped_call_list = gapic_mock.call_args_list + assert len(wrapped_call_list) > 0 + for call in wrapped_call_list: + client_info = call.kwargs["client_info"] + assert client_info is not None, f"{call} has no client_info" + wrapped_user_agent_sorted = " ".join( + sorted(client_info.to_user_agent().split(" ")) + ) + assert VENEER_HEADER_REGEX.match( + wrapped_user_agent_sorted + ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" + client.close() + + def test__start_background_channel_refresh_task_exists(self): + client = self._make_client(project="project-id", use_emulator=False) + assert client._channel_refresh_task is not None + with mock.patch.object(asyncio, "create_task") as create_task: + client._start_background_channel_refresh() + create_task.assert_not_called() + client.close() + + def test__start_background_channel_refresh(self): + client = self._make_client(project="project-id") + with mock.patch.object( + client, "_ping_and_warm_instances", CrossSync._Sync_Impl.Mock() + ) as ping_and_warm: + client._emulator_host = None + client._start_background_channel_refresh() + assert client._channel_refresh_task is not None + assert isinstance(client._channel_refresh_task, CrossSync._Sync_Impl.Task) + CrossSync._Sync_Impl.sleep(0.1) + assert ping_and_warm.call_count == 1 + client.close() + + def test__ping_and_warm_instances(self): + """test ping and warm with mocked asyncio.gather""" + client_mock = mock.Mock() + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + with mock.patch.object( + CrossSync._Sync_Impl, "gather_partials", CrossSync._Sync_Impl.Mock() + ) as gather: + gather.side_effect = lambda partials, **kwargs: [None for _ in partials] + channel = mock.Mock() + client_mock._active_instances = [] + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel=channel + ) + assert len(result) == 0 + assert gather.call_args[1]["return_exceptions"] is True + assert gather.call_args[1]["sync_executor"] == client_mock._executor + client_mock._active_instances = [ + (mock.Mock(), mock.Mock(), mock.Mock()) + ] * 4 + gather.reset_mock() + channel.reset_mock() + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel=channel + ) + assert len(result) == 4 + gather.assert_called_once() + partial_list = 
gather.call_args.args[0] + assert len(partial_list) == 4 + grpc_call_args = channel.unary_unary().call_args_list + for idx, (_, kwargs) in enumerate(grpc_call_args): + ( + expected_instance, + expected_table, + expected_app_profile, + ) = client_mock._active_instances[idx] + request = kwargs["request"] + assert request["name"] == expected_instance + assert request["app_profile_id"] == expected_app_profile + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] + == f"name={expected_instance}&app_profile_id={expected_app_profile}" + ) + + def test__ping_and_warm_single_instance(self): + """should be able to call ping and warm with single instance""" + client_mock = mock.Mock() + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + with mock.patch.object( + CrossSync._Sync_Impl, "gather_partials", CrossSync._Sync_Impl.Mock() + ) as gather: + gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] + client_mock._active_instances = [mock.Mock()] * 100 + test_key = ("test-instance", "test-table", "test-app-profile") + result = self._get_target_class()._ping_and_warm_instances( + client_mock, test_key + ) + assert len(result) == 1 + grpc_call_args = ( + client_mock.transport.grpc_channel.unary_unary().call_args_list + ) + assert len(grpc_call_args) == 1 + kwargs = grpc_call_args[0][1] + request = kwargs["request"] + assert request["name"] == "test-instance" + assert request["app_profile_id"] == "test-app-profile" + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" + ) + + @pytest.mark.parametrize( + "refresh_interval, wait_time, expected_sleep", + [(0, 0, 0), (0, 1, 0), (10, 0, 10), (10, 5, 5), (10, 10, 0), (10, 15, 0)], + ) + def test__manage_channel_first_sleep( + self, refresh_interval, wait_time, expected_sleep + ): + import time + + with mock.patch.object(time, "monotonic") as monotonic: + monotonic.return_value = 0 + with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: + sleep.side_effect = asyncio.CancelledError + try: + client = self._make_client(project="project-id") + client._channel_init_time = -wait_time + client._manage_channel(refresh_interval, refresh_interval) + except asyncio.CancelledError: + pass + sleep.assert_called_once() + call_time = sleep.call_args[0][1] + assert ( + abs(call_time - expected_sleep) < 0.1 + ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" + client.close() + + def test__manage_channel_ping_and_warm(self): + """_manage channel should call ping and warm internally""" + import time + import threading + + client_mock = mock.Mock() + client_mock._is_closed.is_set.return_value = False + client_mock._channel_init_time = time.monotonic() + orig_channel = client_mock.transport.grpc_channel + sleep_tuple = ( + (asyncio, "sleep") + if CrossSync._Sync_Impl.is_async + else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple): + orig_channel.close.side_effect = asyncio.CancelledError + ping_and_warm = ( + client_mock._ping_and_warm_instances + ) = CrossSync._Sync_Impl.Mock() + try: + self._get_target_class()._manage_channel(client_mock, 10) + except asyncio.CancelledError: + pass + assert ping_and_warm.call_count == 2 + assert client_mock.transport._grpc_channel != 
orig_channel + called_with = [call[1]["channel"] for call in ping_and_warm.call_args_list] + assert orig_channel in called_with + assert client_mock.transport.grpc_channel in called_with + + @pytest.mark.parametrize( + "refresh_interval, num_cycles, expected_sleep", + [(None, 1, 60 * 35), (10, 10, 100), (10, 1, 10)], + ) + def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sleep): + import time + import random + + channel = mock.Mock() + channel.close = CrossSync._Sync_Impl.Mock() + with mock.patch.object(random, "uniform") as uniform: + uniform.side_effect = lambda min_, max_: min_ + with mock.patch.object(time, "time") as time_mock: + time_mock.return_value = 0 + with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: + sleep.side_effect = [None for i in range(num_cycles - 1)] + [ + asyncio.CancelledError + ] + client = self._make_client(project="project-id") + client.transport._grpc_channel = channel + with mock.patch.object( + client.transport, "create_channel", CrossSync._Sync_Impl.Mock + ): + try: + if refresh_interval is not None: + client._manage_channel( + refresh_interval, refresh_interval, grace_period=0 + ) + else: + client._manage_channel(grace_period=0) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + total_sleep = sum([call[0][1] for call in sleep.call_args_list]) + assert ( + abs(total_sleep - expected_sleep) < 0.1 + ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" + client.close() + + def test__manage_channel_random(self): + import random + + with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: + with mock.patch.object(random, "uniform") as uniform: + uniform.return_value = 0 + try: + uniform.side_effect = asyncio.CancelledError + client = self._make_client(project="project-id") + except asyncio.CancelledError: + uniform.side_effect = None + uniform.reset_mock() + sleep.reset_mock() + with mock.patch.object(client.transport, "create_channel"): + min_val = 200 + max_val = 205 + uniform.side_effect = lambda min_, max_: min_ + sleep.side_effect = [None, asyncio.CancelledError] + try: + client._manage_channel(min_val, max_val, grace_period=0) + except asyncio.CancelledError: + pass + assert uniform.call_count == 2 + uniform_args = [call[0] for call in uniform.call_args_list] + for found_min, found_max in uniform_args: + assert found_min == min_val + assert found_max == max_val + + @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) + def test__manage_channel_refresh(self, num_cycles): + expected_refresh = 0.5 + grpc_lib = grpc.aio if CrossSync._Sync_Impl.is_async else grpc + new_channel = grpc_lib.insecure_channel("localhost:8080") + with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: + sleep.side_effect = [None for i in range(num_cycles)] + [RuntimeError] + with mock.patch.object( + CrossSync._Sync_Impl.grpc_helpers, "create_channel" + ) as create_channel: + create_channel.return_value = new_channel + client = self._make_client(project="project-id") + create_channel.reset_mock() + try: + client._manage_channel( + refresh_interval_min=expected_refresh, + refresh_interval_max=expected_refresh, + grace_period=0, + ) + except RuntimeError: + pass + assert sleep.call_count == num_cycles + 1 + assert create_channel.call_count == num_cycles + client.close() + + def test__register_instance(self): + """test instance registration""" + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: 
f"prefix/{b}" + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_task = None + client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() + table_mock = mock.Mock() + self._get_target_class()._register_instance( + client_mock, "instance-1", table_mock + ) + assert client_mock._start_background_channel_refresh.call_count == 1 + expected_key = ( + "prefix/instance-1", + table_mock.table_name, + table_mock.app_profile_id, + ) + assert len(active_instances) == 1 + assert expected_key == tuple(list(active_instances)[0]) + assert len(instance_owners) == 1 + assert expected_key == tuple(list(instance_owners)[0]) + client_mock._channel_refresh_task = mock.Mock() + table_mock2 = mock.Mock() + self._get_target_class()._register_instance( + client_mock, "instance-2", table_mock2 + ) + assert client_mock._start_background_channel_refresh.call_count == 1 + assert ( + client_mock._ping_and_warm_instances.call_args[0][0][0] + == "prefix/instance-2" + ) + assert client_mock._ping_and_warm_instances.call_count == 1 + assert len(active_instances) == 2 + assert len(instance_owners) == 2 + expected_key2 = ( + "prefix/instance-2", + table_mock2.table_name, + table_mock2.app_profile_id, + ) + assert any( + [ + expected_key2 == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + assert any( + [ + expected_key2 == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + def test__register_instance_duplicate(self): + """test double instance registration. Should be no-op""" + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}" + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_task = object() + mock_channels = [mock.Mock()] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() + table_mock = mock.Mock() + expected_key = ( + "prefix/instance-1", + table_mock.table_name, + table_mock.app_profile_id, + ) + self._get_target_class()._register_instance( + client_mock, "instance-1", table_mock + ) + assert len(active_instances) == 1 + assert expected_key == tuple(list(active_instances)[0]) + assert len(instance_owners) == 1 + assert expected_key == tuple(list(instance_owners)[0]) + assert client_mock._ping_and_warm_instances.call_count == 1 + self._get_target_class()._register_instance( + client_mock, "instance-1", table_mock + ) + assert len(active_instances) == 1 + assert expected_key == tuple(list(active_instances)[0]) + assert len(instance_owners) == 1 + assert expected_key == tuple(list(instance_owners)[0]) + assert client_mock._ping_and_warm_instances.call_count == 1 + + @pytest.mark.parametrize( + "insert_instances,expected_active,expected_owner_keys", + [ + ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), + ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), + ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", "p")], [("1", "t", "p")]), + ( + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + ), + ], + ) + def test__register_instance_state( + self, insert_instances, expected_active, expected_owner_keys + ): + """test that active_instances and instance_owners are updated as expected""" + 
client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: b + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_task = None + client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() + table_mock = mock.Mock() + for instance, table, profile in insert_instances: + table_mock.table_name = table + table_mock.app_profile_id = profile + self._get_target_class()._register_instance( + client_mock, instance, table_mock + ) + assert len(active_instances) == len(expected_active) + assert len(instance_owners) == len(expected_owner_keys) + for expected in expected_active: + assert any( + [ + expected == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + for expected in expected_owner_keys: + assert any( + [ + expected == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + def test__remove_instance_registration(self): + client = self._make_client(project="project-id") + table = mock.Mock() + client._register_instance("instance-1", table) + client._register_instance("instance-2", table) + assert len(client._active_instances) == 2 + assert len(client._instance_owners.keys()) == 2 + instance_1_path = client._gapic_client.instance_path( + client.project, "instance-1" + ) + instance_1_key = (instance_1_path, table.table_name, table.app_profile_id) + instance_2_path = client._gapic_client.instance_path( + client.project, "instance-2" + ) + instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) + assert len(client._instance_owners[instance_1_key]) == 1 + assert list(client._instance_owners[instance_1_key])[0] == id(table) + assert len(client._instance_owners[instance_2_key]) == 1 + assert list(client._instance_owners[instance_2_key])[0] == id(table) + success = client._remove_instance_registration("instance-1", table) + assert success + assert len(client._active_instances) == 1 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 1 + assert client._active_instances == {instance_2_key} + success = client._remove_instance_registration("fake-key", table) + assert not success + assert len(client._active_instances) == 1 + client.close() + + def test__multiple_table_registration(self): + """registering with multiple tables with the same key should + add multiple owners to instance_owners, but only keep one copy + of shared key in active_instances""" + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + with self._make_client(project="project-id") as client: + with client.get_table("instance_1", "table_1") as table_1: + instance_1_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._active_instances) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + with client.get_table("instance_1", "table_1") as table_2: + assert table_2._register_instance_future is not None + table_2._register_instance_future.result() + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._active_instances) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + with 
client.get_table("instance_1", "table_3") as table_3: + assert table_3._register_instance_future is not None + table_3._register_instance_future.result() + instance_3_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_3_key = _WarmedInstanceKey( + instance_3_path, table_3.table_name, table_3.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._instance_owners[instance_3_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + assert id(table_3) in client._instance_owners[instance_3_key] + assert len(client._active_instances) == 1 + assert instance_1_key in client._active_instances + assert id(table_2) not in client._instance_owners[instance_1_key] + assert len(client._active_instances) == 0 + assert instance_1_key not in client._active_instances + assert len(client._instance_owners[instance_1_key]) == 0 + + def test__multiple_instance_registration(self): + """registering with multiple instance keys should update the key + in instance_owners and active_instances""" + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + with self._make_client(project="project-id") as client: + with client.get_table("instance_1", "table_1") as table_1: + assert table_1._register_instance_future is not None + table_1._register_instance_future.result() + with client.get_table("instance_2", "table_2") as table_2: + assert table_2._register_instance_future is not None + table_2._register_instance_future.result() + instance_1_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + instance_2_path = client._gapic_client.instance_path( + client.project, "instance_2" + ) + instance_2_key = _WarmedInstanceKey( + instance_2_path, table_2.table_name, table_2.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._instance_owners[instance_2_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_2_key] + assert len(client._active_instances) == 1 + assert instance_1_key in client._active_instances + assert len(client._instance_owners[instance_2_key]) == 0 + assert len(client._instance_owners[instance_1_key]) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + assert len(client._active_instances) == 0 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 0 + + def test_get_table(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + client = self._make_client(project="project-id") + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + table = client.get_table( + expected_instance_id, expected_table_id, expected_app_profile_id + ) + CrossSync._Sync_Impl.yield_to_event_loop() + assert isinstance(table, CrossSync._Sync_Impl.TestTable._get_target_class()) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + 
== f"projects/{client.project}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + client.close() + + def test_get_table_arg_passthrough(self): + """All arguments passed in get_table should be sent to constructor""" + with self._make_client(project="project-id") as client: + with mock.patch.object( + CrossSync._Sync_Impl.TestTable._get_target_class(), "__init__" + ) as mock_constructor: + mock_constructor.return_value = None + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_args = (1, "test", {"test": 2}) + expected_kwargs = {"hello": "world", "test": 2} + client.get_table( + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + mock_constructor.assert_called_once_with( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + + def test_get_table_context_manager(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_project_id = "project-id" + with mock.patch.object( + CrossSync._Sync_Impl.TestTable._get_target_class(), "close" + ) as close_mock: + with self._make_client(project=expected_project_id) as client: + with client.get_table( + expected_instance_id, expected_table_id, expected_app_profile_id + ) as table: + CrossSync._Sync_Impl.yield_to_event_loop() + assert isinstance( + table, CrossSync._Sync_Impl.TestTable._get_target_class() + ) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert close_mock.call_count == 1 + + def test_close(self): + client = self._make_client(project="project-id", use_emulator=False) + task = client._channel_refresh_task + assert task is not None + assert not task.done() + with mock.patch.object( + client.transport, "close", CrossSync._Sync_Impl.Mock() + ) as close_mock: + client.close() + close_mock.assert_called_once() + assert task.done() + assert client._channel_refresh_task is None + + def test_close_with_timeout(self): + expected_timeout = 19 + client = self._make_client(project="project-id", use_emulator=False) + with mock.patch.object( + CrossSync._Sync_Impl, "wait", CrossSync._Sync_Impl.Mock() + ) as wait_for_mock: + client.close(timeout=expected_timeout) + wait_for_mock.assert_called_once() + assert wait_for_mock.call_args[1]["timeout"] == expected_timeout + client.close() + + def test_context_manager(self): + from functools import partial + + close_mock = CrossSync._Sync_Impl.Mock() + 
true_close = None + with self._make_client(project="project-id", use_emulator=False) as client: + true_close = partial(client.close) + client.close = close_mock + assert not client._channel_refresh_task.done() + assert client.project == "project-id" + assert client._active_instances == set() + close_mock.assert_not_called() + close_mock.assert_called_once() + true_close() + + +@CrossSync._Sync_Impl.add_mapping_decorator("TestTable") +class TestTable: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + @staticmethod + def _get_target_class(): + return CrossSync._Sync_Impl.Table + + def test_table_ctor(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_operation_timeout = 123 + expected_attempt_timeout = 12 + expected_read_rows_operation_timeout = 1.5 + expected_read_rows_attempt_timeout = 0.5 + expected_mutate_rows_operation_timeout = 2.5 + expected_mutate_rows_attempt_timeout = 0.75 + client = self._make_client() + assert not client._active_instances + table = self._get_target_class()( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + default_operation_timeout=expected_operation_timeout, + default_attempt_timeout=expected_attempt_timeout, + default_read_rows_operation_timeout=expected_read_rows_operation_timeout, + default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout, + default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, + ) + CrossSync._Sync_Impl.yield_to_event_loop() + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert table.default_operation_timeout == expected_operation_timeout + assert table.default_attempt_timeout == expected_attempt_timeout + assert ( + table.default_read_rows_operation_timeout + == expected_read_rows_operation_timeout + ) + assert ( + table.default_read_rows_attempt_timeout + == expected_read_rows_attempt_timeout + ) + assert ( + table.default_mutate_rows_operation_timeout + == expected_mutate_rows_operation_timeout + ) + assert ( + table.default_mutate_rows_attempt_timeout + == expected_mutate_rows_attempt_timeout + ) + table._register_instance_future + assert table._register_instance_future.done() + assert not table._register_instance_future.cancelled() + assert table._register_instance_future.exception() is None + client.close() + + def test_table_ctor_defaults(self): + """should provide default timeout values and app_profile_id""" + expected_table_id = "table-id" + expected_instance_id = "instance-id" + client = self._make_client() + assert not client._active_instances + table = self._get_target_class()( + client, expected_instance_id, expected_table_id + ) + CrossSync._Sync_Impl.yield_to_event_loop() + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert table.app_profile_id is None + assert table.client is client + assert table.default_operation_timeout == 60 + assert 
table.default_read_rows_operation_timeout == 600 + assert table.default_mutate_rows_operation_timeout == 600 + assert table.default_attempt_timeout == 20 + assert table.default_read_rows_attempt_timeout == 20 + assert table.default_mutate_rows_attempt_timeout == 60 + client.close() + + def test_table_ctor_invalid_timeout_values(self): + """bad timeout values should raise ValueError""" + client = self._make_client() + timeout_pairs = [ + ("default_operation_timeout", "default_attempt_timeout"), + ( + "default_read_rows_operation_timeout", + "default_read_rows_attempt_timeout", + ), + ( + "default_mutate_rows_operation_timeout", + "default_mutate_rows_attempt_timeout", + ), + ] + for operation_timeout, attempt_timeout in timeout_pairs: + with pytest.raises(ValueError) as e: + self._get_target_class()(client, "", "", **{attempt_timeout: -1}) + assert "attempt_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._get_target_class()(client, "", "", **{operation_timeout: -1}) + assert "operation_timeout must be greater than 0" in str(e.value) + client.close() + + @pytest.mark.parametrize( + "fn_name,fn_args,is_stream,extra_retryables", + [ + ("read_rows_stream", (ReadRowsQuery(),), True, ()), + ("read_rows", (ReadRowsQuery(),), True, ()), + ("read_row", (b"row_key",), True, ()), + ("read_rows_sharded", ([ReadRowsQuery()],), True, ()), + ("row_exists", (b"row_key",), True, ()), + ("sample_row_keys", (), False, ()), + ("mutate_row", (b"row_key", [mock.Mock()]), False, ()), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + False, + (_MutateRowsIncomplete,), + ), + ], + ) + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + def test_customizable_retryable_errors( + self, + input_retryables, + expected_retryables, + fn_name, + fn_args, + is_stream, + extra_retryables, + ): + """Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer.""" + retry_fn = "retry_target" + if is_stream: + retry_fn += "_stream" + retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" + with mock.patch( + f"google.cloud.bigtable.data._cross_sync.{retry_fn}" + ) as retry_fn_mock: + with self._make_client() as client: + table = client.get_table("instance-id", "table-id") + expected_predicate = expected_retryables.__contains__ + retry_fn_mock.side_effect = RuntimeError("stop early") + with mock.patch( + "google.api_core.retry.if_exception_type" + ) as predicate_builder_mock: + predicate_builder_mock.return_value = expected_predicate + with pytest.raises(Exception): + test_fn = table.__getattribute__(fn_name) + test_fn(*fn_args, retryable_errors=input_retryables) + predicate_builder_mock.assert_called_once_with( + *expected_retryables, *extra_retryables + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + assert retry_call_args[1] is expected_predicate + + @pytest.mark.parametrize( + "fn_name,fn_args,gapic_fn", + [ + ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), + ("read_rows", (ReadRowsQuery(),), 
"read_rows"), + ("read_row", (b"row_key",), "read_rows"), + ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), + ("row_exists", (b"row_key",), "read_rows"), + ("sample_row_keys", (), "sample_row_keys"), + ("mutate_row", (b"row_key", [mutations.DeleteAllFromRow()]), "mutate_row"), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + "mutate_rows", + ), + ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), + ( + "read_modify_write_row", + (b"row_key", IncrementRule("f", "q")), + "read_modify_write_row", + ), + ], + ) + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): + profile = "profile" if include_app_profile else None + client = self._make_client() + transport_mock = mock.MagicMock() + rpc_mock = CrossSync._Sync_Impl.Mock() + transport_mock._wrapped_methods.__getitem__.return_value = rpc_mock + gapic_client = client._gapic_client + gapic_client._transport = transport_mock + gapic_client._is_universe_domain_valid = True + table = self._get_target_class()(client, "instance-id", "table-id", profile) + try: + test_fn = table.__getattribute__(fn_name) + maybe_stream = test_fn(*fn_args) + [i for i in maybe_stream] + except Exception: + pass + assert rpc_mock.call_count == 1 + kwargs = rpc_mock.call_args_list[0][1] + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + routing_str = metadata[0][1] + assert "table_name=" + table.table_name in routing_str + if include_app_profile: + assert "app_profile_id=profile" in routing_str + else: + assert "app_profile_id=" not in routing_str + + +@CrossSync._Sync_Impl.add_mapping_decorator("TestReadRows") +class TestReadRows: + """ + Tests for table.read_rows and related methods. 
+ """ + + @staticmethod + def _get_operation_class(): + return CrossSync._Sync_Impl._ReadRowsOperation + + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + def _make_table(self, *args, **kwargs): + client_mock = mock.Mock() + client_mock._register_instance.side_effect = ( + lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() + ) + client_mock._remove_instance_registration.side_effect = ( + lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() + ) + kwargs["instance_id"] = kwargs.get( + "instance_id", args[0] if args else "instance" + ) + kwargs["table_id"] = kwargs.get( + "table_id", args[1] if len(args) > 1 else "table" + ) + client_mock._gapic_client.table_path.return_value = kwargs["table_id"] + client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] + return CrossSync._Sync_Impl.TestTable._get_target_class()( + client_mock, *args, **kwargs + ) + + def _make_stats(self): + from google.cloud.bigtable_v2.types import RequestStats + from google.cloud.bigtable_v2.types import FullReadStatsView + from google.cloud.bigtable_v2.types import ReadIterationStats + + return RequestStats( + full_read_stats_view=FullReadStatsView( + read_iteration_stats=ReadIterationStats( + rows_seen_count=1, + rows_returned_count=2, + cells_seen_count=3, + cells_returned_count=4, + ) + ) + ) + + @staticmethod + def _make_chunk(*args, **kwargs): + from google.cloud.bigtable_v2 import ReadRowsResponse + + kwargs["row_key"] = kwargs.get("row_key", b"row_key") + kwargs["family_name"] = kwargs.get("family_name", "family_name") + kwargs["qualifier"] = kwargs.get("qualifier", b"qualifier") + kwargs["value"] = kwargs.get("value", b"value") + kwargs["commit_row"] = kwargs.get("commit_row", True) + return ReadRowsResponse.CellChunk(*args, **kwargs) + + @staticmethod + def _make_gapic_stream( + chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0 + ): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list, sleep_time): + self.chunk_list = chunk_list + self.idx = -1 + self.sleep_time = sleep_time + + def __iter__(self): + return self + + def __next__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + if sleep_time: + CrossSync._Sync_Impl.sleep(self.sleep_time) + chunk = self.chunk_list[self.idx] + if isinstance(chunk, Exception): + raise chunk + else: + return ReadRowsResponse(chunks=[chunk]) + raise CrossSync._Sync_Impl.StopIteration + + def cancel(self): + pass + + return mock_stream(chunk_list, sleep_time) + + def execute_fn(self, table, *args, **kwargs): + return table.read_rows(*args, **kwargs) + + def test_read_rows(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + self._make_chunk(row_key=b"test_2"), + ] + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + results = self.execute_fn(table, query, operation_timeout=3) + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" + + def test_read_rows_stream(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + self._make_chunk(row_key=b"test_2"), + ] + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + 
gen = table.read_rows_stream(query, operation_timeout=3) + results = [row for row in gen] + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" + + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_read_rows_query_matches_request(self, include_app_profile): + from google.cloud.bigtable.data import RowRange + from google.cloud.bigtable.data.row_filters import PassAllFilter + + app_profile_id = "app_profile_id" if include_app_profile else None + with self._make_table(app_profile_id=app_profile_id) as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([]) + row_keys = [b"test_1", "test_2"] + row_ranges = RowRange("1start", "2end") + filter_ = PassAllFilter(True) + limit = 99 + query = ReadRowsQuery( + row_keys=row_keys, + row_ranges=row_ranges, + row_filter=filter_, + limit=limit, + ) + results = table.read_rows(query, operation_timeout=3) + assert len(results) == 0 + call_request = read_rows.call_args_list[0][0][0] + query_pb = query._to_pb(table) + assert call_request == query_pb + + @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) + def test_read_rows_timeout(self, operation_timeout): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + query = ReadRowsQuery() + chunks = [self._make_chunk(row_key=b"test_1")] + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=0.15 + ) + try: + table.read_rows(query, operation_timeout=operation_timeout) + except core_exceptions.DeadlineExceeded as e: + assert ( + e.message + == f"operation_timeout of {operation_timeout:0.1f}s exceeded" + ) + + @pytest.mark.parametrize( + "per_request_t, operation_t, expected_num", + [(0.05, 0.08, 2), (0.05, 0.14, 3), (0.05, 0.24, 5)], + ) + def test_read_rows_attempt_timeout(self, per_request_t, operation_t, expected_num): + """Ensures that the attempt_timeout is respected and that the number of + requests is as expected. 
+ + operation_timeout does not cancel the request, so we expect the number of + requests to be the ceiling of operation_timeout / attempt_timeout.""" + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + expected_last_timeout = operation_t - (expected_num - 1) * per_request_t + with mock.patch("random.uniform", side_effect=lambda a, b: 0): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=per_request_t + ) + query = ReadRowsQuery() + chunks = [core_exceptions.DeadlineExceeded("mock deadline")] + try: + table.read_rows( + query, + operation_timeout=operation_t, + attempt_timeout=per_request_t, + ) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + if expected_num == 0: + assert retry_exc is None + else: + assert type(retry_exc) is RetryExceptionGroup + assert f"{expected_num} failed attempts" in str(retry_exc) + assert len(retry_exc.exceptions) == expected_num + for sub_exc in retry_exc.exceptions: + assert sub_exc.message == "mock deadline" + assert read_rows.call_count == expected_num + for _, call_kwargs in read_rows.call_args_list[:-1]: + assert call_kwargs["timeout"] == per_request_t + assert call_kwargs["retry"] is None + assert ( + abs( + read_rows.call_args_list[-1][1]["timeout"] + - expected_last_timeout + ) + < 0.05 + ) + + @pytest.mark.parametrize( + "exc_type", + [ + core_exceptions.Aborted, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], + ) + def test_read_rows_retryable_error(self, exc_type): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + table.read_rows(query, operation_timeout=0.1) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + root_cause = retry_exc.exceptions[0] + assert type(root_cause) is exc_type + assert root_cause == expected_error + + @pytest.mark.parametrize( + "exc_type", + [ + core_exceptions.Cancelled, + core_exceptions.PreconditionFailed, + core_exceptions.NotFound, + core_exceptions.PermissionDenied, + core_exceptions.Conflict, + core_exceptions.InternalServerError, + core_exceptions.TooManyRequests, + core_exceptions.ResourceExhausted, + InvalidChunk, + ], + ) + def test_read_rows_non_retryable_error(self, exc_type): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + table.read_rows(query, operation_timeout=0.1) + except exc_type as e: + assert e == expected_error + + def test_read_rows_revise_request(self): + """Ensure that _revise_request is called between retries""" + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import RowSet + + return_val = RowSet() + with mock.patch.object( + self._get_operation_class(), "_revise_request_rowset" + ) as revise_rowset: + revise_rowset.return_value = return_val + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + row_keys = [b"test_1", b"test_2", b"test_3"] + query = 
ReadRowsQuery(row_keys=row_keys) + chunks = [ + self._make_chunk(row_key=b"test_1"), + core_exceptions.Aborted("mock retryable error"), + ] + try: + table.read_rows(query) + except InvalidChunk: + revise_rowset.assert_called() + first_call_kwargs = revise_rowset.call_args_list[0].kwargs + assert first_call_kwargs["row_set"] == query._to_pb(table).rows + assert first_call_kwargs["last_seen_row_key"] == b"test_1" + revised_call = read_rows.call_args_list[1].args[0] + assert revised_call.rows == return_val + + def test_read_rows_default_timeouts(self): + """Ensure that the default timeouts are set on the read rows operation when not overridden""" + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + with self._make_table( + default_read_rows_operation_timeout=operation_timeout, + default_read_rows_attempt_timeout=attempt_timeout, + ) as table: + try: + table.read_rows(ReadRowsQuery()) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout + + def test_read_rows_default_timeout_override(self): + """When timeouts are passed, they overwrite default values""" + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + with self._make_table( + default_operation_timeout=99, default_attempt_timeout=97 + ) as table: + try: + table.read_rows( + ReadRowsQuery(), + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + ) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout + + def test_read_row(self): + """Test reading a single row""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + row = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert row == expected_result + assert read_rows.call_count == 1 + (args, kwargs) = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + + def test_read_row_w_filter(self): + """Test reading a single row with an added filter""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + mock_filter = mock.Mock() + expected_filter = {"filter": "mock filter"} + mock_filter._to_dict.return_value = expected_filter + row = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + row_filter=expected_filter, + ) + assert row == expected_result + assert 
read_rows.call_count == 1 + (args, kwargs) = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter == expected_filter + + def test_read_row_no_response(self): + """should return None if row does not exist""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = lambda *args, **kwargs: [] + expected_op_timeout = 8 + expected_req_timeout = 4 + result = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert result is None + assert read_rows.call_count == 1 + (args, kwargs) = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + + @pytest.mark.parametrize( + "return_value,expected_result", + [([], False), ([object()], True), ([object(), object()], True)], + ) + def test_row_exists(self, return_value, expected_result): + """Test checking for row existence""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = lambda *args, **kwargs: return_value + expected_op_timeout = 1 + expected_req_timeout = 2 + result = table.row_exists( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert expected_result == result + assert read_rows.call_count == 1 + (args, kwargs) = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + expected_filter = { + "chain": { + "filters": [ + {"cells_per_row_limit_filter": 1}, + {"strip_value_transformer": True}, + ] + } + } + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter._to_dict() == expected_filter + + +class TestReadRowsSharded: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + def test_read_rows_sharded_empty_query(self): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as exc: + table.read_rows_sharded([]) + assert "empty sharded_query" in str(exc.value) + + def test_read_rows_sharded_multiple_queries(self): + """Test with multiple queries. 
Should return results from both""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + read_rows.side_effect = lambda *args, **kwargs: CrossSync._Sync_Impl.TestReadRows._make_gapic_stream( + [ + CrossSync._Sync_Impl.TestReadRows._make_chunk(row_key=k) + for k in args[0].rows.row_keys + ] + ) + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + result = table.read_rows_sharded([query_1, query_2]) + assert len(result) == 2 + assert result[0].row_key == b"test_1" + assert result[1].row_key == b"test_2" + + @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) + def test_read_rows_sharded_multiple_queries_calls(self, n_queries): + """Each query should trigger a separate read_rows call""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + query_list = [ReadRowsQuery() for _ in range(n_queries)] + table.read_rows_sharded(query_list) + assert read_rows.call_count == n_queries + + def test_read_rows_sharded_errors(self): + """Errors should be exposed as ShardedReadRowsExceptionGroups""" + from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedQueryShardError + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = RuntimeError("mock error") + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + with pytest.raises(ShardedReadRowsExceptionGroup) as exc: + table.read_rows_sharded([query_1, query_2]) + exc_group = exc.value + assert isinstance(exc_group, ShardedReadRowsExceptionGroup) + assert len(exc.value.exceptions) == 2 + assert isinstance(exc.value.exceptions[0], FailedQueryShardError) + assert isinstance(exc.value.exceptions[0].__cause__, RuntimeError) + assert exc.value.exceptions[0].index == 0 + assert exc.value.exceptions[0].query == query_1 + assert isinstance(exc.value.exceptions[1], FailedQueryShardError) + assert isinstance(exc.value.exceptions[1].__cause__, RuntimeError) + assert exc.value.exceptions[1].index == 1 + assert exc.value.exceptions[1].query == query_2 + + def test_read_rows_sharded_concurrent(self): + """Ensure sharded requests are concurrent""" + import time + + def mock_call(*args, **kwargs): + CrossSync._Sync_Impl.sleep(0.1) + return [mock.Mock()] + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = mock_call + queries = [ReadRowsQuery() for _ in range(10)] + start_time = time.monotonic() + result = table.read_rows_sharded(queries) + call_time = time.monotonic() - start_time + assert read_rows.call_count == 10 + assert len(result) == 10 + assert call_time < 0.5 + + def test_read_rows_sharded_concurrency_limit(self): + """Only 10 queries should be processed concurrently. 
Others should be queued
+
+        A new query should start as soon as a previous one finishes"""
+        from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT
+
+        assert _CONCURRENCY_LIMIT == 10
+        num_queries = 15
+        increment_time = 0.05
+        max_time = increment_time * (_CONCURRENCY_LIMIT - 1)
+        rpc_times = [min(i * increment_time, max_time) for i in range(num_queries)]
+
+        def mock_call(*args, **kwargs):
+            next_sleep = rpc_times.pop(0)
+            # use the sync sleep shim so the mock actually blocks
+            CrossSync._Sync_Impl.sleep(next_sleep)
+            return [mock.Mock()]
+
+        starting_timeout = 10
+        with self._make_client() as client:
+            with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(num_queries)]
+                    table.read_rows_sharded(queries, operation_timeout=starting_timeout)
+                    assert read_rows.call_count == num_queries
+                    rpc_start_list = [
+                        starting_timeout - kwargs["operation_timeout"]
+                        for (_, kwargs) in read_rows.call_args_list
+                    ]
+                    eps = 0.01
+                    assert all(
+                        (rpc_start_list[i] < eps for i in range(_CONCURRENCY_LIMIT))
+                    )
+                    for i in range(num_queries - _CONCURRENCY_LIMIT):
+                        idx = i + _CONCURRENCY_LIMIT
+                        assert rpc_start_list[idx] - i * increment_time < eps
+
+    def test_read_rows_sharded_expiry(self):
+        """If the operation times out before all shards complete, should raise
+        a ShardedReadRowsExceptionGroup"""
+        from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT
+        from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+        from google.api_core.exceptions import DeadlineExceeded
+
+        operation_timeout = 0.1
+        num_queries = 15
+        sleeps = [0] * _CONCURRENCY_LIMIT + [DeadlineExceeded("times up")] * (
+            num_queries - _CONCURRENCY_LIMIT
+        )
+
+        def mock_call(*args, **kwargs):
+            next_item = sleeps.pop(0)
+            if isinstance(next_item, Exception):
+                raise next_item
+            else:
+                # use the sync sleep shim so the mock actually blocks
+                CrossSync._Sync_Impl.sleep(next_item)
+            return [mock.Mock()]
+
+        with self._make_client() as client:
+            with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(num_queries)]
+                    with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
+                        table.read_rows_sharded(
+                            queries, operation_timeout=operation_timeout
+                        )
+                    assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
+                    assert len(exc.value.exceptions) == num_queries - _CONCURRENCY_LIMIT
+                    assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT
+
+    def test_read_rows_sharded_negative_batch_timeout(self):
+        """Run shards whose start is delayed past the operation timeout
+
+        Those late shards should raise DeadlineExceeded errors"""
+        from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+        from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT
+        from google.api_core.exceptions import DeadlineExceeded
+
+        def mock_call(*args, **kwargs):
+            CrossSync._Sync_Impl.sleep(0.06)
+            return [mock.Mock()]
+
+        with self._make_client() as client:
+            with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    num_calls = 15
+                    queries = [ReadRowsQuery() for _ in range(num_calls)]
+                    with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
+                        table.read_rows_sharded(queries, operation_timeout=0.05)
+                    assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
+                    assert len(exc.value.exceptions) >= num_calls - _CONCURRENCY_LIMIT
+                    assert all(
+                        (
+                            isinstance(e.__cause__,
DeadlineExceeded) + for e in exc.value.exceptions + ) + ) + + +class TestSampleRowKeys: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): + from google.cloud.bigtable_v2.types import SampleRowKeysResponse + + for value in sample_list: + yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) + + def test_sample_row_keys(self): + """Test that method returns the expected key samples""" + samples = [(b"test_1", 0), (b"test_2", 100), (b"test_3", 200)] + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream(samples) + result = table.sample_row_keys() + assert len(result) == 3 + assert all((isinstance(r, tuple) for r in result)) + assert all((isinstance(r[0], bytes) for r in result)) + assert all((isinstance(r[1], int) for r in result)) + assert result[0] == samples[0] + assert result[1] == samples[1] + assert result[2] == samples[2] + + def test_sample_row_keys_bad_timeout(self): + """should raise error if timeout is negative""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.sample_row_keys(operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + table.sample_row_keys(attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) + + def test_sample_row_keys_default_timeout(self): + """Should fallback to using table default operation_timeout""" + expected_timeout = 99 + with self._make_client() as client: + with client.get_table( + "i", + "t", + default_operation_timeout=expected_timeout, + default_attempt_timeout=expected_timeout, + ) as table: + with mock.patch.object( + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + result = table.sample_row_keys() + (_, kwargs) = sample_row_keys.call_args + assert abs(kwargs["timeout"] - expected_timeout) < 0.1 + assert result == [] + assert kwargs["retry"] is None + + def test_sample_row_keys_gapic_params(self): + """make sure arguments are propagated to gapic call as expected""" + expected_timeout = 10 + expected_profile = "test1" + instance = "instance_name" + table_id = "my_table" + with self._make_client() as client: + with client.get_table( + instance, table_id, app_profile_id=expected_profile + ) as table: + with mock.patch.object( + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + table.sample_row_keys(attempt_timeout=expected_timeout) + (args, kwargs) = sample_row_keys.call_args + assert len(args) == 0 + assert len(kwargs) == 4 + assert kwargs["timeout"] == expected_timeout + assert kwargs["app_profile_id"] == expected_profile + assert kwargs["table_name"] == table.table_name + assert kwargs["retry"] is None + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_sample_row_keys_retryable_errors(self, retryable_exception): + """retryable errors should be retried 
until timeout""" + from google.api_core.exceptions import DeadlineExceeded + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), + ) as sample_row_keys: + sample_row_keys.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + table.sample_row_keys(operation_timeout=0.05) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) > 0 + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, + ], + ) + def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): + """non-retryable errors should cause a raise""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), + ) as sample_row_keys: + sample_row_keys.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + table.sample_row_keys() + + +class TestMutateRow: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize( + "mutation_arg", + [ + mutations.SetCell("family", b"qualifier", b"value"), + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 + ), + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromFamily("family"), + mutations.DeleteAllFromRow(), + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + ], + ], + ) + def test_mutate_row(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.return_value = None + table.mutate_row( + "row_key", + mutation_arg, + attempt_timeout=expected_attempt_timeout, + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0].kwargs + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["row_key"] == b"row_key" + formatted_mutations = ( + [mutation._to_pb() for mutation in mutation_arg] + if isinstance(mutation_arg, list) + else [mutation_arg._to_pb()] + ) + assert kwargs["mutations"] == formatted_mutations + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_mutate_row_retryable_errors(self, retryable_exception): + from google.api_core.exceptions import DeadlineExceeded + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + 
mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + mutation = mutations.DeleteAllFromRow() + assert mutation.is_idempotent() is True + table.mutate_row("row_key", mutation, operation_timeout=0.01) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_mutate_row_non_idempotent_retryable_errors(self, retryable_exception): + """Non-idempotent mutations should not be retried""" + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(retryable_exception): + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + assert mutation.is_idempotent() is False + table.mutate_row("row_key", mutation, operation_timeout=0.2) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, + ], + ) + def test_mutate_row_non_retryable_errors(self, non_retryable_exception): + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + mutation = mutations.SetCell( + "family", + b"qualifier", + b"value", + timestamp_micros=1234567890, + ) + assert mutation.is_idempotent() is True + table.mutate_row("row_key", mutation, operation_timeout=0.2) + + @pytest.mark.parametrize("mutations", [[], None]) + def test_mutate_row_no_mutations(self, mutations): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.mutate_row("key", mutations=mutations) + assert e.value.args[0] == "No mutations provided" + + +class TestBulkMutateRows: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + def _mock_response(self, response_list): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 + + statuses = [] + for response in response_list: + if isinstance(response, core_exceptions.GoogleAPICallError): + statuses.append( + status_pb2.Status( + message=str(response), code=response.grpc_status_code.value[0] + ) + ) + else: + statuses.append(status_pb2.Status(code=0)) + entries = [ + MutateRowsResponse.Entry(index=i, status=statuses[i]) + for i in range(len(response_list)) + ] + + def generator(): + yield MutateRowsResponse(entries=entries) + + return generator() + + @pytest.mark.parametrize( + "mutation_arg", + [ + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 + ) + ], + [mutations.DeleteRangeFromColumn("family", b"qualifier")], + [mutations.DeleteAllFromFamily("family")], + [mutations.DeleteAllFromRow()], + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + 
], + ], + ) + def test_bulk_mutate_rows(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = self._mock_response([None]) + bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) + table.bulk_mutate_rows( + [bulk_mutation], attempt_timeout=expected_attempt_timeout + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"] == [bulk_mutation._to_pb()] + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None + + def test_bulk_mutate_rows_multiple_entries(self): + """Test mutations with no errors""" + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = self._mock_response([None, None]) + mutation_list = [mutations.DeleteAllFromRow()] + entry_1 = mutations.RowMutationEntry(b"row_key_1", mutation_list) + entry_2 = mutations.RowMutationEntry(b"row_key_2", mutation_list) + table.bulk_mutate_rows([entry_1, entry_2]) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"][0] == entry_1._to_pb() + assert kwargs["entries"][1] == entry_2._to_pb() + + @pytest.mark.parametrize( + "exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_rows_idempotent_mutation_error_retryable(self, exception): + """Individual idempotent mutations should be retried if they fail with a retryable error""" + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], exception) + assert isinstance( + cause.exceptions[-1], core_exceptions.DeadlineExceeded + ) + + @pytest.mark.parametrize( + "exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + core_exceptions.Aborted, + ], + ) + def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable(self, exception): + """Individual idempotent mutations should not be retried if they fail with a non-retryable error""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + 
MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_idempotent_retryable_request_errors(self, retryable_exception): + """Individual idempotent mutations should be retried if the request fails with a retryable error""" + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_rows_non_idempotent_retryable_errors( + self, retryable_exception + ): + """Non-Idempotent mutations should never be retried""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [retryable_exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is False + table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, retryable_exception) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + 
core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + ], + ) + def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): + """If the request fails with a non-retryable error, mutations should not be retried""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, non_retryable_exception) + + def test_bulk_mutate_error_index(self): + """Test partial failure, partial success. Errors should be associated with the correct index""" + from google.api_core.exceptions import ( + DeadlineExceeded, + ServiceUnavailable, + FailedPrecondition, + ) + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = [ + self._mock_response([None, ServiceUnavailable("mock"), None]), + self._mock_response([DeadlineExceeded("mock")]), + self._mock_response([FailedPrecondition("final")]), + ] + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry( + f"row_key_{i}".encode(), [mutation] + ) + for i in range(3) + ] + assert mutation.is_idempotent() is True + table.bulk_mutate_rows(entries, operation_timeout=1000) + assert len(e.value.exceptions) == 1 + failed = e.value.exceptions[0] + assert isinstance(failed, FailedMutationEntryError) + assert failed.index == 1 + assert failed.entry == entries[1] + cause = failed.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) == 3 + assert isinstance(cause.exceptions[0], ServiceUnavailable) + assert isinstance(cause.exceptions[1], DeadlineExceeded) + assert isinstance(cause.exceptions[2], FailedPrecondition) + + def test_bulk_mutate_error_recovery(self): + """If an error occurs, then resolves, no exception should be raised""" + from google.api_core.exceptions import DeadlineExceeded + + with self._make_client(project="project") as client: + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = [ + self._mock_response([DeadlineExceeded("mock")]), + self._mock_response([None]), + ] + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry(f"row_key_{i}".encode(), [mutation]) + for i in range(3) 
+ ] + table.bulk_mutate_rows(entries, operation_timeout=1000) + + +class TestCheckAndMutateRow: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize("gapic_result", [True, False]) + def test_check_and_mutate(self, gapic_result): + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + app_profile = "app_profile_id" + with self._make_client() as client: + with client.get_table( + "instance", "table", app_profile_id=app_profile + ) as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=gapic_result + ) + row_key = b"row_key" + predicate = None + true_mutations = [mock.Mock()] + false_mutations = [mock.Mock(), mock.Mock()] + operation_timeout = 0.2 + found = table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutations, + false_case_mutations=false_mutations, + operation_timeout=operation_timeout, + ) + assert found == gapic_result + kwargs = mock_gapic.call_args[1] + assert kwargs["table_name"] == table.table_name + assert kwargs["row_key"] == row_key + assert kwargs["predicate_filter"] == predicate + assert kwargs["true_mutations"] == [ + m._to_pb() for m in true_mutations + ] + assert kwargs["false_mutations"] == [ + m._to_pb() for m in false_mutations + ] + assert kwargs["app_profile_id"] == app_profile + assert kwargs["timeout"] == operation_timeout + assert kwargs["retry"] is None + + def test_check_and_mutate_bad_timeout(self): + """Should raise error if operation_timeout < 0""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=[mock.Mock()], + false_case_mutations=[], + operation_timeout=-1, + ) + assert str(e.value) == "operation_timeout must be greater than 0" + + def test_check_and_mutate_single_mutations(self): + """if single mutations are passed, they should be internally wrapped in a list""" + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + true_mutation = SetCell("family", b"qualifier", b"value") + false_mutation = SetCell("family", b"qualifier", b"value") + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == [true_mutation._to_pb()] + assert kwargs["false_mutations"] == [false_mutation._to_pb()] + + def test_check_and_mutate_predicate_object(self): + """predicate filter should be passed to gapic request""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + mock_predicate = mock.Mock() + predicate_pb = {"predicate": "dict"} + mock_predicate._to_pb.return_value = predicate_pb + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + 
table.check_and_mutate_row( + b"row_key", mock_predicate, false_case_mutations=[mock.Mock()] + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["predicate_filter"] == predicate_pb + assert mock_predicate._to_pb.call_count == 1 + assert kwargs["retry"] is None + + def test_check_and_mutate_mutations_parsing(self): + """mutations objects should be converted to protos""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + mutations = [mock.Mock() for _ in range(5)] + for idx, mutation in enumerate(mutations): + mutation._to_pb.return_value = f"fake {idx}" + mutations.append(DeleteAllFromRow()) + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=mutations[0:2], + false_case_mutations=mutations[2:], + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == ["fake 0", "fake 1"] + assert kwargs["false_mutations"] == [ + "fake 2", + "fake 3", + "fake 4", + DeleteAllFromRow()._to_pb(), + ] + assert all( + (mutation._to_pb.call_count == 1 for mutation in mutations[:5]) + ) + + +class TestReadModifyWriteRow: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize( + "call_rules,expected_rules", + [ + ( + AppendValueRule("f", "c", b"1"), + [AppendValueRule("f", "c", b"1")._to_pb()], + ), + ( + [AppendValueRule("f", "c", b"1")], + [AppendValueRule("f", "c", b"1")._to_pb()], + ), + (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_pb()]), + ( + [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)], + [ + AppendValueRule("f", "c", b"1")._to_pb(), + IncrementRule("f", "c", 1)._to_pb(), + ], + ), + ], + ) + def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): + """Test that the gapic call is called with given rules""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row("key", call_rules) + assert mock_gapic.call_count == 1 + found_kwargs = mock_gapic.call_args_list[0][1] + assert found_kwargs["rules"] == expected_rules + assert found_kwargs["retry"] is None + + @pytest.mark.parametrize("rules", [[], None]) + def test_read_modify_write_no_rules(self, rules): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.read_modify_write_row("key", rules=rules) + assert e.value.args[0] == "rules must contain at least one item" + + def test_read_modify_write_call_defaults(self): + instance = "instance1" + table_id = "table1" + project = "project1" + row_key = "row_key1" + with self._make_client(project=project) as client: + with client.get_table(instance, table_id) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert ( + kwargs["table_name"] + == f"projects/{project}/instances/{instance}/tables/{table_id}" + ) + assert kwargs["app_profile_id"] is 
None + assert kwargs["row_key"] == row_key.encode() + assert kwargs["timeout"] > 1 + + def test_read_modify_write_call_overrides(self): + row_key = b"row_key1" + expected_timeout = 12345 + profile_id = "profile1" + with self._make_client() as client: + with client.get_table( + "instance", "table_id", app_profile_id=profile_id + ) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row( + row_key, mock.Mock(), operation_timeout=expected_timeout + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["app_profile_id"] is profile_id + assert kwargs["row_key"] == row_key + assert kwargs["timeout"] == expected_timeout + + def test_read_modify_write_string_key(self): + row_key = "string_row_key1" + with self._make_client() as client: + with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["row_key"] == row_key.encode() + + def test_read_modify_write_row_building(self): + """results from gapic call should be used to construct row""" + from google.cloud.bigtable.data.row import Row + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + from google.cloud.bigtable_v2.types import Row as RowPB + + mock_response = ReadModifyWriteRowResponse(row=RowPB()) + with self._make_client() as client: + with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + with mock.patch.object(Row, "_from_pb") as constructor_mock: + mock_gapic.return_value = mock_response + table.read_modify_write_row("key", mock.Mock()) + assert constructor_mock.call_count == 1 + constructor_mock.assert_called_once_with(mock_response.row) + + +class TestExecuteQuery: + TABLE_NAME = "TABLE_NAME" + INSTANCE_NAME = "INSTANCE_NAME" + + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + def _make_gapic_stream(self, sample_list: list["ExecuteQueryResponse" | Exception]): + class MockStream: + def __init__(self, sample_list): + self.sample_list = sample_list + + def __aiter__(self): + return self + + def __iter__(self): + return self + + def __next__(self): + if not self.sample_list: + raise CrossSync._Sync_Impl.StopIteration + value = self.sample_list.pop(0) + if isinstance(value, Exception): + raise value + return value + + def __anext__(self): + return self.__next__() + + return MockStream(sample_list) + + def resonse_with_metadata(self): + from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse + + schema = {"a": "string_type", "b": "int64_type"} + return ExecuteQueryResponse( + { + "metadata": { + "proto_schema": { + "columns": [ + {"name": name, "type_": {_type: {}}} + for (name, _type) in schema.items() + ] + } + } + } + ) + + def resonse_with_result(self, *args, resume_token=None): + from google.cloud.bigtable_v2.types.data import ProtoRows, Value as PBValue + from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse + + if resume_token is None: + resume_token_dict = {} + else: + resume_token_dict = {"resume_token": resume_token} + values = [] + for column_value in args: + if column_value is None: + pb_value = PBValue({}) + else: + pb_value = PBValue( + { + 
"int_value" + if isinstance(column_value, int) + else "string_value": column_value + } + ) + values.append(pb_value) + rows = ProtoRows(values=values) + return ExecuteQueryResponse( + { + "results": { + "proto_rows_batch": {"batch_data": ProtoRows.serialize(rows)}, + **resume_token_dict, + } + } + ) + + def test_execute_query(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert results[1]["a"] == "test2" + assert results[1]["b"] == 9 + assert results[2]["a"] == "test3" + assert results[2]["b"] is None + assert execute_query_mock.call_count == 1 + + def test_execute_query_with_params(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME} WHERE b=@b", + self.INSTANCE_NAME, + parameters={"b": 9}, + ) + results = [r for r in result] + assert len(results) == 1 + assert results[0]["a"] == "test2" + assert results[0]["b"] == 9 + assert execute_query_mock.call_count == 1 + + def test_execute_query_error_before_metadata(self): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + DeadlineExceeded(""), + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert len(results) == 3 + assert execute_query_mock.call_count == 2 + + def test_execute_query_error_after_metadata(self): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + self.resonse_with_metadata(), + DeadlineExceeded(""), + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = 
client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert len(results) == 3 + assert execute_query_mock.call_count == 2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [] + + def test_execute_query_with_retries(self): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + DeadlineExceeded(""), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + DeadlineExceeded(""), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert results[1]["a"] == "test2" + assert results[1]["b"] == 9 + assert results[2]["a"] == "test3" + assert results[2]["b"] is None + assert len(results) == 3 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"r1", b"r2"] + + @pytest.mark.parametrize( + "exception", + [ + core_exceptions.DeadlineExceeded(""), + core_exceptions.Aborted(""), + core_exceptions.ServiceUnavailable(""), + ], + ) + def test_execute_query_retryable_error(self, exception): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test", resume_token=b"t1"), + exception, + self.resonse_with_result(8, resume_token=b"t2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert len(results) == 1 + assert execute_query_mock.call_count == 2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"t1"] + + def test_execute_query_retry_partial_row(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test", resume_token=b"t1"), + core_exceptions.DeadlineExceeded(""), + self.resonse_with_result(8, resume_token=b"t2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert execute_query_mock.call_count == 2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"t1"] + + @pytest.mark.parametrize( + 
"ExceptionType", + [ + core_exceptions.InvalidArgument, + core_exceptions.FailedPrecondition, + core_exceptions.PermissionDenied, + core_exceptions.MethodNotImplemented, + core_exceptions.Cancelled, + core_exceptions.AlreadyExists, + core_exceptions.OutOfRange, + core_exceptions.DataLoss, + core_exceptions.Unauthenticated, + core_exceptions.NotFound, + core_exceptions.ResourceExhausted, + core_exceptions.Unknown, + core_exceptions.InternalServerError, + ], + ) + def test_execute_query_non_retryable(self, ExceptionType): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + ExceptionType(""), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + r = CrossSync._Sync_Impl.next(result) + assert r["a"] == "test" + assert r["b"] == 8 + with pytest.raises(ExceptionType): + r = CrossSync._Sync_Impl.next(result) + assert execute_query_mock.call_count == 1 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [] + + def test_execute_query_metadata_received_multiple_times_detected(self): + values = [self.resonse_with_metadata(), self.resonse_with_metadata()] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + with pytest.raises( + Exception, match="Invalid ExecuteQuery response received" + ): + [ + r + for r in client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + ] diff --git a/tests/unit/data/_sync_autogen/test_mutations_batcher.py b/tests/unit/data/_sync_autogen/test_mutations_batcher.py new file mode 100644 index 000000000..59ea621ac --- /dev/null +++ b/tests/unit/data/_sync_autogen/test_mutations_batcher.py @@ -0,0 +1,1078 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This file is automatically generated by CrossSync. Do not edit manually. 
+ +import pytest +import mock +import asyncio +import time +import google.api_core.exceptions as core_exceptions +import google.api_core.retry +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data._cross_sync import CrossSync + + +class Test_FlowControl: + @staticmethod + def _target_class(): + return CrossSync._Sync_Impl._FlowControl + + def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): + return self._target_class()(max_mutation_count, max_mutation_bytes) + + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def test_ctor(self): + max_mutation_count = 9 + max_mutation_bytes = 19 + instance = self._make_one(max_mutation_count, max_mutation_bytes) + assert instance._max_mutation_count == max_mutation_count + assert instance._max_mutation_bytes == max_mutation_bytes + assert instance._in_flight_mutation_count == 0 + assert instance._in_flight_mutation_bytes == 0 + assert isinstance(instance._capacity_condition, CrossSync._Sync_Impl.Condition) + + def test_ctor_invalid_values(self): + """Test that values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(0, 1) + assert "max_mutation_count must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(1, 0) + assert "max_mutation_bytes must be greater than 0" in str(e.value) + + @pytest.mark.parametrize( + "max_count,max_size,existing_count,existing_size,new_count,new_size,expected", + [ + (1, 1, 0, 0, 0, 0, True), + (1, 1, 1, 1, 1, 1, False), + (10, 10, 0, 0, 0, 0, True), + (10, 10, 0, 0, 9, 9, True), + (10, 10, 0, 0, 11, 9, True), + (10, 10, 0, 1, 11, 9, True), + (10, 10, 1, 0, 11, 9, False), + (10, 10, 0, 0, 9, 11, True), + (10, 10, 1, 0, 9, 11, True), + (10, 10, 0, 1, 9, 11, False), + (10, 1, 0, 0, 1, 0, True), + (1, 10, 0, 0, 0, 8, True), + (float("inf"), float("inf"), 0, 0, 10000000000.0, 10000000000.0, True), + (8, 8, 0, 0, 10000000000.0, 10000000000.0, True), + (12, 12, 6, 6, 5, 5, True), + (12, 12, 5, 5, 6, 6, True), + (12, 12, 6, 6, 6, 6, True), + (12, 12, 6, 6, 7, 7, False), + (12, 12, 0, 0, 13, 13, True), + (12, 12, 12, 0, 0, 13, True), + (12, 12, 0, 12, 13, 0, True), + (12, 12, 1, 1, 13, 13, False), + (12, 12, 1, 1, 0, 13, False), + (12, 12, 1, 1, 13, 0, False), + ], + ) + def test__has_capacity( + self, + max_count, + max_size, + existing_count, + existing_size, + new_count, + new_size, + expected, + ): + """_has_capacity should return True if the new mutation will will not exceed the max count or size""" + instance = self._make_one(max_count, max_size) + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + assert instance._has_capacity(new_count, new_size) == expected + + @pytest.mark.parametrize( + "existing_count,existing_size,added_count,added_size,new_count,new_size", + [ + (0, 0, 0, 0, 0, 0), + (2, 2, 1, 1, 1, 1), + (2, 0, 1, 0, 1, 0), + (0, 2, 0, 1, 0, 1), + (10, 10, 0, 0, 10, 10), + (10, 10, 5, 5, 5, 5), + (0, 0, 1, 1, -1, -1), + ], + ) + def test_remove_from_flow_value_update( + self, + existing_count, + existing_size, + added_count, + added_size, + new_count, + new_size, + ): + """completed mutations should lower the inflight values""" + instance = self._make_one() + instance._in_flight_mutation_count = existing_count + 
instance._in_flight_mutation_bytes = existing_size + mutation = self._make_mutation(added_count, added_size) + instance.remove_from_flow(mutation) + assert instance._in_flight_mutation_count == new_count + assert instance._in_flight_mutation_bytes == new_size + + def test__remove_from_flow_unlock(self): + """capacity condition should notify after mutation is complete""" + instance = self._make_one(10, 10) + instance._in_flight_mutation_count = 10 + instance._in_flight_mutation_bytes = 10 + + def task_routine(): + with instance._capacity_condition: + instance._capacity_condition.wait_for( + lambda: instance._has_capacity(1, 1) + ) + + import threading + + thread = threading.Thread(target=task_routine) + thread.start() + task_alive = thread.is_alive + CrossSync._Sync_Impl.sleep(0.05) + assert task_alive() is True + mutation = self._make_mutation(count=0, size=5) + instance.remove_from_flow([mutation]) + CrossSync._Sync_Impl.sleep(0.05) + assert instance._in_flight_mutation_count == 10 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is True + instance._in_flight_mutation_bytes = 10 + mutation = self._make_mutation(count=5, size=0) + instance.remove_from_flow([mutation]) + CrossSync._Sync_Impl.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 10 + assert task_alive() is True + instance._in_flight_mutation_count = 10 + mutation = self._make_mutation(count=5, size=5) + instance.remove_from_flow([mutation]) + CrossSync._Sync_Impl.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is False + + @pytest.mark.parametrize( + "mutations,count_cap,size_cap,expected_results", + [ + ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]), + ( + [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)], + 5, + 5, + [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]], + ), + ], + ) + def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): + """Test batching with various flow control settings""" + mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] + instance = self._make_one(count_cap, size_cap) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + @pytest.mark.parametrize( + "mutations,max_limit,expected_results", + [ + ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]), + ([(1, 1)] * 10, 1, [[(1, 1)] for _ in range(10)]), + ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]), + ], + ) + def test_add_to_flow_max_mutation_limits( + self, mutations, max_limit, expected_results + ): + """Test flow control running up against the max API limit + Should submit request early, even if the flow control has room for more""" + subpath = "_async" if CrossSync._Sync_Impl.is_async else "_sync_autogen" + path = f"google.cloud.bigtable.data.{subpath}.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT" + with mock.patch(path, max_limit): + mutation_objs = [ + self._make_mutation(count=m[0], size=m[1]) for m in 
mutations + ] + instance = self._make_one(float("inf"), float("inf")) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + def test_add_to_flow_oversize(self): + """mutations over the flow control limits should still be accepted""" + instance = self._make_one(2, 3) + large_size_mutation = self._make_mutation(count=1, size=10) + large_count_mutation = self._make_mutation(count=10, size=1) + results = [out for out in instance.add_to_flow([large_size_mutation])] + assert len(results) == 1 + instance.remove_from_flow(results[0]) + count_results = [out for out in instance.add_to_flow(large_count_mutation)] + assert len(count_results) == 1 + + +class TestMutationsBatcher: + def _get_target_class(self): + return CrossSync._Sync_Impl.MutationsBatcher + + def _make_one(self, table=None, **kwargs): + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import ServiceUnavailable + + if table is None: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 10 + table.default_mutate_rows_retryable_errors = ( + DeadlineExceeded, + ServiceUnavailable, + ) + return self._get_target_class()(table, **kwargs) + + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def test_ctor_defaults(self): + with mock.patch.object( + self._get_target_class(), + "_timer_routine", + return_value=CrossSync._Sync_Impl.Future(), + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = [Exception] + with self._make_one(table) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._max_mutation_count == 100000 + assert instance._flow_control._max_mutation_bytes == 104857600 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert ( + instance._operation_timeout + == table.default_mutate_rows_operation_timeout + ) + assert ( + instance._attempt_timeout + == table.default_mutate_rows_attempt_timeout + ) + assert ( + instance._retryable_errors + == table.default_mutate_rows_retryable_errors + ) + CrossSync._Sync_Impl.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == 5 + assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) + + def test_ctor_explicit(self): + """Test with explicit parameters""" + with mock.patch.object( + self._get_target_class(), + "_timer_routine", + return_value=CrossSync._Sync_Impl.Future(), + ) as flush_timer_mock: + table = mock.Mock() + flush_interval = 20 + flush_limit_count = 17 + flush_limit_bytes = 
19 + flow_control_max_mutation_count = 1001 + flow_control_max_bytes = 12 + operation_timeout = 11 + attempt_timeout = 2 + retryable_errors = [Exception] + with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=operation_timeout, + batch_attempt_timeout=attempt_timeout, + batch_retryable_errors=retryable_errors, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert ( + instance._flow_control._max_mutation_count + == flow_control_max_mutation_count + ) + assert ( + instance._flow_control._max_mutation_bytes == flow_control_max_bytes + ) + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert instance._operation_timeout == operation_timeout + assert instance._attempt_timeout == attempt_timeout + assert instance._retryable_errors == retryable_errors + CrossSync._Sync_Impl.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == flush_interval + assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) + + def test_ctor_no_flush_limits(self): + """Test with None for flush limits""" + with mock.patch.object( + self._get_target_class(), + "_timer_routine", + return_value=CrossSync._Sync_Impl.Future(), + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = () + flush_interval = None + flush_limit_count = None + flush_limit_bytes = None + with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._staged_entries == [] + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + CrossSync._Sync_Impl.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] is None + assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) + + def test_ctor_invalid_values(self): + """Test that timeout values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(batch_operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(batch_attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) + + def test_default_argument_consistency(self): + """We supply default arguments in MutationsBatcherAsync.__init__, and in + 
table.mutations_batcher. Make sure any changes to defaults are applied to + both places""" + import inspect + + get_batcher_signature = dict( + inspect.signature(CrossSync._Sync_Impl.Table.mutations_batcher).parameters + ) + get_batcher_signature.pop("self") + batcher_init_signature = dict( + inspect.signature(self._get_target_class()).parameters + ) + batcher_init_signature.pop("table") + assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) + assert len(get_batcher_signature) == 8 + assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) + for arg_name in get_batcher_signature.keys(): + assert ( + get_batcher_signature[arg_name].default + == batcher_init_signature[arg_name].default + ) + + @pytest.mark.parametrize("input_val", [None, 0, -1]) + def test__start_flush_timer_w_empty_input(self, input_val): + """Empty/invalid timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + with self._make_one() as instance: + (sleep_obj, sleep_method) = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + result = instance._timer_routine(input_val) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + assert result is None + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test__start_flush_timer_call_when_closed(self): + """closed batcher's timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + with self._make_one() as instance: + instance.close() + flush_mock.reset_mock() + (sleep_obj, sleep_method) = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + instance._timer_routine(10) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + + @pytest.mark.parametrize("num_staged", [0, 1, 10]) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test__flush_timer(self, num_staged): + """Timer should continue to call _schedule_flush in a loop""" + from google.cloud.bigtable.data._cross_sync import CrossSync + + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + expected_sleep = 12 + with self._make_one(flush_interval=expected_sleep) as instance: + loop_num = 3 + instance._staged_entries = [mock.Mock()] * num_staged + with mock.patch.object( + CrossSync._Sync_Impl, "event_wait" + ) as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] + with pytest.raises(TabError): + self._get_target_class()._timer_routine( + instance, expected_sleep + ) + assert sleep_mock.call_count == loop_num + 1 + sleep_kwargs = sleep_mock.call_args[1] + assert sleep_kwargs["timeout"] == expected_sleep + assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) + + def test__flush_timer_close(self): + """Timer should continue terminate after close""" + with mock.patch.object(self._get_target_class(), "_schedule_flush"): + with self._make_one() as instance: + assert instance._flush_timer.done() is False + instance.close() + assert instance._flush_timer.done() is True + + def test_append_closed(self): + """Should raise exception""" + instance = self._make_one() + instance.close() + with pytest.raises(RuntimeError): + instance.append(mock.Mock()) + + def test_append_wrong_mutation(self): + """Mutation objects should raise an exception. 
+ Only support RowMutationEntry""" + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + with self._make_one() as instance: + expected_error = "invalid mutation type: DeleteAllFromRow. Only RowMutationEntry objects are supported by batcher" + with pytest.raises(ValueError) as e: + instance.append(DeleteAllFromRow()) + assert str(e.value) == expected_error + + def test_append_outside_flow_limits(self): + """entries larger than mutation limits are still processed""" + with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + oversized_entry = self._make_mutation(count=0, size=2) + instance.append(oversized_entry) + assert instance._staged_entries == [oversized_entry] + assert instance._staged_count == 0 + assert instance._staged_bytes == 2 + instance._staged_entries = [] + with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + overcount_entry = self._make_mutation(count=2, size=0) + instance.append(overcount_entry) + assert instance._staged_entries == [overcount_entry] + assert instance._staged_count == 2 + assert instance._staged_bytes == 0 + instance._staged_entries = [] + + def test_append_flush_runs_after_limit_hit(self): + """If the user appends a bunch of entries above the flush limits back-to-back, + it should still flush in a single task""" + with mock.patch.object( + self._get_target_class(), "_execute_mutate_rows" + ) as op_mock: + with self._make_one(flush_limit_bytes=100) as instance: + + def mock_call(*args, **kwargs): + return [] + + op_mock.side_effect = mock_call + instance.append(self._make_mutation(size=99)) + num_entries = 10 + for _ in range(num_entries): + instance.append(self._make_mutation(size=1)) + instance._wait_for_batch_results(*instance._flush_jobs) + assert op_mock.call_count == 1 + sent_batch = op_mock.call_args[0][0] + assert len(sent_batch) == 2 + assert len(instance._staged_entries) == num_entries - 1 + + @pytest.mark.parametrize( + "flush_count,flush_bytes,mutation_count,mutation_bytes,expect_flush", + [ + (10, 10, 1, 1, False), + (10, 10, 9, 9, False), + (10, 10, 10, 1, True), + (10, 10, 1, 10, True), + (10, 10, 10, 10, True), + (1, 1, 10, 10, True), + (1, 1, 0, 0, False), + ], + ) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_append( + self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush + ): + """test appending different mutations, and checking if it causes a flush""" + with self._make_one( + flush_limit_mutation_count=flush_count, flush_limit_bytes=flush_bytes + ) as instance: + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + instance.append(mutation) + assert flush_mock.call_count == bool(expect_flush) + assert instance._staged_count == mutation_count + assert instance._staged_bytes == mutation_bytes + assert instance._staged_entries == [mutation] + instance._staged_entries = [] + + def test_append_multiple_sequentially(self): + """Append multiple mutations""" + with self._make_one( + flush_limit_mutation_count=8, flush_limit_bytes=8 + ) as instance: + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = self._make_mutation(count=2, size=3) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + 
instance.append(mutation) + assert flush_mock.call_count == 0 + assert instance._staged_count == 2 + assert instance._staged_bytes == 3 + assert len(instance._staged_entries) == 1 + instance.append(mutation) + assert flush_mock.call_count == 0 + assert instance._staged_count == 4 + assert instance._staged_bytes == 6 + assert len(instance._staged_entries) == 2 + instance.append(mutation) + assert flush_mock.call_count == 1 + assert instance._staged_count == 6 + assert instance._staged_bytes == 9 + assert len(instance._staged_entries) == 3 + instance._staged_entries = [] + + def test_flush_flow_control_concurrent_requests(self): + """requests should happen in parallel if flow control breaks up single flush into batches""" + import time + + num_calls = 10 + fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] + with self._make_one(flow_control_max_mutation_count=1) as instance: + with mock.patch.object( + instance, "_execute_mutate_rows", CrossSync._Sync_Impl.Mock() + ) as op_mock: + + def mock_call(*args, **kwargs): + CrossSync._Sync_Impl.sleep(0.1) + return [] + + op_mock.side_effect = mock_call + start_time = time.monotonic() + instance._staged_entries = fake_mutations + instance._schedule_flush() + CrossSync._Sync_Impl.sleep(0.01) + for i in range(num_calls): + instance._flow_control.remove_from_flow( + [self._make_mutation(count=1)] + ) + CrossSync._Sync_Impl.sleep(0.01) + instance._wait_for_batch_results(*instance._flush_jobs) + duration = time.monotonic() - start_time + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert duration < 0.5 + assert op_mock.call_count == num_calls + + def test_schedule_flush_no_mutations(self): + """schedule flush should return None if no staged mutations""" + with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + for i in range(3): + assert instance._schedule_flush() is None + assert flush_mock.call_count == 0 + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_schedule_flush_with_mutations(self): + """if new mutations exist, should add a new flush task to _flush_jobs""" + with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + flush_mock.side_effect = lambda x: time.sleep(0.1) + for i in range(1, 4): + mutation = mock.Mock() + instance._staged_entries = [mutation] + instance._schedule_flush() + assert instance._staged_entries == [] + asyncio.sleep(0) + assert instance._staged_entries == [] + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert flush_mock.call_count == 1 + flush_mock.reset_mock() + + def test__flush_internal(self): + """_flush_internal should: + - await previous flush call + - delegate batching to _flow_control + - call _execute_mutate_rows on each batch + - update self.exceptions and self._entries_processed_since_last_raise""" + num_entries = 10 + with self._make_one() as instance: + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + + def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [self._make_mutation(count=1, size=1)] * num_entries + instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise == num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + 
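+ # Note: in the sync build the scheduled flush job is assumed to behave like a
+ # CrossSync future (it exposes result()); the test below only checks the
+ # observable contents of _flush_jobs before and after completion.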
+ def test_flush_clears_job_list(self): + """a job should be added to _flush_jobs when _schedule_flush is called, + and removed when it completes""" + with self._make_one() as instance: + with mock.patch.object( + instance, "_flush_internal", CrossSync._Sync_Impl.Mock() + ) as flush_mock: + flush_mock.side_effect = lambda x: time.sleep(0.1) + mutations = [self._make_mutation(count=1, size=1)] + instance._staged_entries = mutations + assert instance._flush_jobs == set() + new_job = instance._schedule_flush() + assert instance._flush_jobs == {new_job} + new_job.result() + assert instance._flush_jobs == set() + + @pytest.mark.parametrize( + "num_starting,num_new_errors,expected_total_errors", + [ + (0, 0, 0), + (0, 1, 1), + (0, 2, 2), + (1, 0, 1), + (1, 1, 2), + (10, 2, 12), + (10, 20, 20), + ], + ) + def test__flush_internal_with_errors( + self, num_starting, num_new_errors, expected_total_errors + ): + """errors returned from _execute_mutate_rows should be added to internal exceptions""" + from google.cloud.bigtable.data import exceptions + + num_entries = 10 + expected_errors = [ + exceptions.FailedMutationEntryError(mock.Mock(), mock.Mock(), ValueError()) + ] * num_new_errors + with self._make_one() as instance: + instance._oldest_exceptions = [mock.Mock()] * num_starting + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + execute_mock.return_value = expected_errors + with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + + def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [self._make_mutation(count=1, size=1)] * num_entries + instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise == num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + found_exceptions = instance._oldest_exceptions + list( + instance._newest_exceptions + ) + assert len(found_exceptions) == expected_total_errors + for i in range(num_starting, expected_total_errors): + assert found_exceptions[i] == expected_errors[i - num_starting] + assert found_exceptions[i].index is None + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + + def _mock_gapic_return(self, num=5): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 + + def gen(num): + for i in range(num): + entry = MutateRowsResponse.Entry( + index=i, status=status_pb2.Status(code=0) + ) + yield MutateRowsResponse(entries=[entry]) + + return gen(num) + + def test_timer_flush_end_to_end(self): + """Flush should automatically trigger after flush_interval""" + num_mutations = 10 + mutations = [self._make_mutation(count=2, size=2)] * num_mutations + with self._make_one(flush_interval=0.05) as instance: + instance._table.default_operation_timeout = 10 + instance._table.default_attempt_timeout = 9 + with mock.patch.object( + instance._table.client._gapic_client, "mutate_rows" + ) as gapic_mock: + gapic_mock.side_effect = ( + lambda *args, **kwargs: self._mock_gapic_return(num_mutations) + ) + for m in mutations: + instance.append(m) + assert instance._entries_processed_since_last_raise == 0 + CrossSync._Sync_Impl.sleep(0.1) + assert instance._entries_processed_since_last_raise == num_mutations + + def test__execute_mutate_rows(self): + with mock.patch.object( + CrossSync._Sync_Impl, "_MutateRowsOperation" + ) as mutate_rows: + mutate_rows.return_value = CrossSync._Sync_Impl.Mock() + start_operation = mutate_rows().start + table = mock.Mock() + 
table.table_name = "test-table" + table.app_profile_id = "test-app-profile" + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + with self._make_one(table) as instance: + batch = [self._make_mutation()] + result = instance._execute_mutate_rows(batch) + assert start_operation.call_count == 1 + (args, kwargs) = mutate_rows.call_args + assert args[0] == table.client._gapic_client + assert args[1] == table + assert args[2] == batch + kwargs["operation_timeout"] == 17 + kwargs["attempt_timeout"] == 13 + assert result == [] + + def test__execute_mutate_rows_returns_errors(self): + """Errors from operation should be retruned as list""" + from google.cloud.bigtable.data.exceptions import ( + MutationsExceptionGroup, + FailedMutationEntryError, + ) + + with mock.patch.object( + CrossSync._Sync_Impl._MutateRowsOperation, "start" + ) as mutate_rows: + err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) + err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) + mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + with self._make_one(table) as instance: + batch = [self._make_mutation()] + result = instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + assert result[0].index is None + assert result[1].index is None + + def test__raise_exceptions(self): + """Raise exceptions and reset error state""" + from google.cloud.bigtable.data import exceptions + + expected_total = 1201 + expected_exceptions = [RuntimeError("mock")] * 3 + with self._make_one() as instance: + instance._oldest_exceptions = expected_exceptions + instance._entries_processed_since_last_raise = expected_total + try: + instance._raise_exceptions() + except exceptions.MutationsExceptionGroup as exc: + assert list(exc.exceptions) == expected_exceptions + assert str(expected_total) in str(exc) + assert instance._entries_processed_since_last_raise == 0 + (instance._oldest_exceptions, instance._newest_exceptions) = ([], []) + instance._raise_exceptions() + + def test___enter__(self): + """Should return self""" + with self._make_one() as instance: + assert instance.__enter__() == instance + + def test___exit__(self): + """aexit should call close""" + with self._make_one() as instance: + with mock.patch.object(instance, "close") as close_mock: + instance.__exit__(None, None, None) + assert close_mock.call_count == 1 + + def test_close(self): + """Should clean up all resources""" + with self._make_one() as instance: + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + with mock.patch.object(instance, "_raise_exceptions") as raise_mock: + instance.close() + assert instance.closed is True + assert instance._flush_timer.done() is True + assert instance._flush_jobs == set() + assert flush_mock.call_count == 1 + assert raise_mock.call_count == 1 + + def test_close_w_exceptions(self): + """Raise exceptions on close""" + from google.cloud.bigtable.data import exceptions + + expected_total = 10 + expected_exceptions = [RuntimeError("mock")] + with self._make_one() as instance: + instance._oldest_exceptions = expected_exceptions + instance._entries_processed_since_last_raise = expected_total + try: + instance.close() + except 
exceptions.MutationsExceptionGroup as exc: + assert list(exc.exceptions) == expected_exceptions + assert str(expected_total) in str(exc) + assert instance._entries_processed_since_last_raise == 0 + (instance._oldest_exceptions, instance._newest_exceptions) = ([], []) + + def test__on_exit(self, recwarn): + """Should raise warnings if unflushed mutations exist""" + with self._make_one() as instance: + instance._on_exit() + assert len(recwarn) == 0 + num_left = 4 + instance._staged_entries = [mock.Mock()] * num_left + with pytest.warns(UserWarning) as w: + instance._on_exit() + assert len(w) == 1 + assert "unflushed mutations" in str(w[0].message).lower() + assert str(num_left) in str(w[0].message) + instance._closed.set() + instance._on_exit() + assert len(recwarn) == 0 + instance._staged_entries = [] + + def test_atexit_registration(self): + """Should run _on_exit on program termination""" + import atexit + + with mock.patch.object(atexit, "register") as register_mock: + assert register_mock.call_count == 0 + with self._make_one(): + assert register_mock.call_count == 1 + + def test_timeout_args_passed(self): + """batch_operation_timeout and batch_attempt_timeout should be used + in api calls""" + with mock.patch.object( + CrossSync._Sync_Impl, + "_MutateRowsOperation", + return_value=CrossSync._Sync_Impl.Mock(), + ) as mutate_rows: + expected_operation_timeout = 17 + expected_attempt_timeout = 13 + with self._make_one( + batch_operation_timeout=expected_operation_timeout, + batch_attempt_timeout=expected_attempt_timeout, + ) as instance: + assert instance._operation_timeout == expected_operation_timeout + assert instance._attempt_timeout == expected_attempt_timeout + instance._execute_mutate_rows([self._make_mutation()]) + assert mutate_rows.call_count == 1 + kwargs = mutate_rows.call_args[1] + assert kwargs["operation_timeout"] == expected_operation_timeout + assert kwargs["attempt_timeout"] == expected_attempt_timeout + + @pytest.mark.parametrize( + "limit,in_e,start_e,end_e", + [ + (10, 0, (10, 0), (10, 0)), + (1, 10, (0, 0), (1, 1)), + (10, 1, (0, 0), (1, 0)), + (10, 10, (0, 0), (10, 0)), + (10, 11, (0, 0), (10, 1)), + (3, 20, (0, 0), (3, 3)), + (10, 20, (0, 0), (10, 10)), + (10, 21, (0, 0), (10, 10)), + (2, 1, (2, 0), (2, 1)), + (2, 1, (1, 0), (2, 0)), + (2, 2, (1, 0), (2, 1)), + (3, 1, (3, 1), (3, 2)), + (3, 3, (3, 1), (3, 3)), + (1000, 5, (999, 0), (1000, 4)), + (1000, 5, (0, 0), (5, 0)), + (1000, 5, (1000, 0), (1000, 5)), + ], + ) + def test__add_exceptions(self, limit, in_e, start_e, end_e): + """Test that the _add_exceptions function properly updates the + _oldest_exceptions and _newest_exceptions lists + Args: + - limit: the _exception_list_limit representing the max size of either list + - in_e: size of list of exceptions to send to _add_exceptions + - start_e: a tuple of ints representing the initial sizes of _oldest_exceptions and _newest_exceptions + - end_e: a tuple of ints representing the expected sizes of _oldest_exceptions and _newest_exceptions + """ + from collections import deque + + input_list = [RuntimeError(f"mock {i}") for i in range(in_e)] + mock_batcher = mock.Mock() + mock_batcher._oldest_exceptions = [ + RuntimeError(f"starting mock {i}") for i in range(start_e[0]) + ] + mock_batcher._newest_exceptions = deque( + [RuntimeError(f"starting mock {i}") for i in range(start_e[1])], + maxlen=limit, + ) + mock_batcher._exception_list_limit = limit + mock_batcher._exceptions_since_last_raise = 0 + self._get_target_class()._add_exceptions(mock_batcher, input_list) + 
assert len(mock_batcher._oldest_exceptions) == end_e[0] + assert len(mock_batcher._newest_exceptions) == end_e[1] + assert mock_batcher._exceptions_since_last_raise == in_e + oldest_list_diff = end_e[0] - start_e[0] + newest_list_diff = min(max(in_e - oldest_list_diff, 0), limit) + for i in range(oldest_list_diff): + assert mock_batcher._oldest_exceptions[i + start_e[0]] == input_list[i] + for i in range(1, newest_list_diff + 1): + assert mock_batcher._newest_exceptions[-i] == input_list[-i] + + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + def test_customizable_retryable_errors(self, input_retryables, expected_retryables): + """Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer.""" + with mock.patch.object( + google.api_core.retry, "if_exception_type" + ) as predicate_builder_mock: + with mock.patch.object( + CrossSync._Sync_Impl, "retry_target" + ) as retry_fn_mock: + table = None + with mock.patch("asyncio.create_task"): + table = CrossSync._Sync_Impl.Table(mock.Mock(), "instance", "table") + with self._make_one( + table, batch_retryable_errors=input_retryables + ) as instance: + assert instance._retryable_errors == expected_retryables + expected_predicate = expected_retryables.__contains__ + predicate_builder_mock.return_value = expected_predicate + retry_fn_mock.side_effect = RuntimeError("stop early") + mutation = self._make_mutation(count=1, size=1) + instance._execute_mutate_rows([mutation]) + predicate_builder_mock.assert_called_once_with( + *expected_retryables, _MutateRowsIncomplete + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + assert retry_call_args[1] is expected_predicate + + def test_large_batch_write(self): + """Test that a large batch of mutations can be written""" + import math + + num_mutations = 10000 + flush_limit = 1000 + mutations = [self._make_mutation(count=1, size=1)] * num_mutations + with self._make_one(flush_limit_mutation_count=flush_limit) as instance: + operation_mock = mock.Mock() + rpc_call_mock = CrossSync._Sync_Impl.Mock() + operation_mock().start = rpc_call_mock + CrossSync._Sync_Impl._MutateRowsOperation = operation_mock + for m in mutations: + instance.append(m) + expected_calls = math.ceil(num_mutations / flush_limit) + assert rpc_call_mock.call_count == expected_calls + assert instance._entries_processed_since_last_raise == num_mutations + assert len(instance._staged_entries) == 0 diff --git a/tests/unit/data/_sync_autogen/test_read_rows_acceptance.py b/tests/unit/data/_sync_autogen/test_read_rows_acceptance.py new file mode 100644 index 000000000..8ceb0daf7 --- /dev/null +++ b/tests/unit/data/_sync_autogen/test_read_rows_acceptance.py @@ -0,0 +1,328 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is automatically generated by CrossSync. Do not edit manually. + +from __future__ import annotations +import os +import warnings +import pytest +import mock +from itertools import zip_longest +from google.cloud.bigtable_v2 import ReadRowsResponse +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.row import Row +from ...v2_client.test_row_merger import ReadRowsTest, TestFile +from google.cloud.bigtable.data._cross_sync import CrossSync + + +class TestReadRowsAcceptance: + @staticmethod + def _get_operation_class(): + return CrossSync._Sync_Impl._ReadRowsOperation + + @staticmethod + def _get_client_class(): + return CrossSync._Sync_Impl.DataClient + + def parse_readrows_acceptance_tests(): + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "../read-rows-acceptance-test.json") + with open(filename) as json_file: + test_json = TestFile.from_json(json_file.read()) + return test_json.read_rows_tests + + @staticmethod + def extract_results_from_row(row: Row): + results = [] + for family, col, cells in row.items(): + for cell in cells: + results.append( + ReadRowsTest.Result( + row_key=row.row_key, + family_name=family, + qualifier=col, + timestamp_micros=cell.timestamp_ns // 1000, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + ) + return results + + @staticmethod + def _coro_wrapper(stream): + return stream + + def _process_chunks(self, *chunks): + def _row_stream(): + yield ReadRowsResponse(chunks=chunks) + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + results = [] + for row in merger: + results.append(row) + return results + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + def test_row_merger_scenario(self, test_case: ReadRowsTest): + def _scenerio_stream(): + for chunk in test_case.chunks: + yield ReadRowsResponse(chunks=[chunk]) + + try: + results = [] + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_scenerio_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + for row in merger: + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + def test_read_rows_scenario(self, test_case: ReadRowsTest): + def _make_gapic_stream(chunk_list: 
list[ReadRowsResponse]): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list): + self.chunk_list = chunk_list + self.idx = -1 + + def __aiter__(self): + return self + + def __iter__(self): + return self + + def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + chunk = self.chunk_list[self.idx] + return ReadRowsResponse(chunks=[chunk]) + raise CrossSync._Sync_Impl.StopIteration + + def __next__(self): + return self.__anext__() + + def cancel(self): + pass + + return mock_stream(chunk_list) + + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + client = self._get_client_class()() + try: + table = client.get_table("instance", "table") + results = [] + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + read_rows.return_value = _make_gapic_stream(test_case.chunks) + for row in table.read_rows_stream(query={}): + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + finally: + client.close() + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + def test_out_of_order_rows(self): + def _row_stream(): + yield ReadRowsResponse(last_scanned_row_key=b"a") + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = b"b" + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + with pytest.raises(InvalidChunk): + for _ in merger: + pass + + def test_bare_reset(self): + first_chunk = ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk( + row_key=b"a", family_name="f", qualifier=b"q", value=b"v" + ) + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, family_name="f") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, value=b"v") + ), + ) + + def test_missing_family(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + qualifier=b"q", + timestamp_micros=1000, + value=b"v", + commit_row=True, + ) + ) + + def test_mid_cell_row_key_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + 
family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), + ) + + def test_mid_cell_family_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + family_name="f2", value=b"v", commit_row=True + ), + ) + + def test_mid_cell_qualifier_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + qualifier=b"q2", value=b"v", commit_row=True + ), + ) + + def test_mid_cell_timestamp_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + timestamp_micros=2000, value=b"v", commit_row=True + ), + ) + + def test_mid_cell_labels_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), + ) diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py index 9bdf17c27..ea93fed55 100644 --- a/tests/unit/data/execute_query/_async/test_query_iterator.py +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/data/execute_query/_sync_autogen/__init__.py b/tests/unit/data/execute_query/_sync_autogen/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/execute_query/_sync_autogen/test_query_iterator.py b/tests/unit/data/execute_query/_sync_autogen/test_query_iterator.py new file mode 100644 index 000000000..77a28ea92 --- /dev/null +++ b/tests/unit/data/execute_query/_sync_autogen/test_query_iterator.py @@ -0,0 +1,163 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This file is automatically generated by CrossSync. Do not edit manually. 
+ +import pytest +import concurrent.futures +from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse +from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes +from google.cloud.bigtable.data._cross_sync import CrossSync + +try: + from unittest import mock +except ImportError: + import mock + + +class MockIterator: + def __init__(self, values, delay=None): + self._values = values + self.idx = 0 + self._delay = delay + + def __iter__(self): + return self + + def __next__(self): + if self.idx >= len(self._values): + raise CrossSync._Sync_Impl.StopIteration + if self._delay is not None: + CrossSync._Sync_Impl.sleep(self._delay) + value = self._values[self.idx] + self.idx += 1 + return value + + +class TestQueryIterator: + @staticmethod + def _target_class(): + return CrossSync._Sync_Impl.ExecuteQueryIterator + + def _make_one(self, *args, **kwargs): + return self._target_class()(*args, **kwargs) + + @pytest.fixture + def proto_byte_stream(self): + proto_rows = [ + proto_rows_bytes({"int_value": 1}, {"int_value": 2}), + proto_rows_bytes({"int_value": 3}, {"int_value": 4}), + proto_rows_bytes({"int_value": 5}, {"int_value": 6}), + ] + messages = [ + *split_bytes_into_chunks(proto_rows[0], num_chunks=2), + *split_bytes_into_chunks(proto_rows[1], num_chunks=3), + proto_rows[2], + ] + stream = [ + ExecuteQueryResponse( + metadata={ + "proto_schema": { + "columns": [ + {"name": "test1", "type_": TYPE_INT}, + {"name": "test2", "type_": TYPE_INT}, + ] + } + } + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[0]}} + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[1]}, + "resume_token": b"token1", + } + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[2]}} + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[3]}} + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[4]}, + "resume_token": b"token2", + } + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[5]}, + "resume_token": b"token3", + } + ), + ] + return stream + + def test_iterator(self, proto_byte_stream): + client_mock = mock.Mock() + client_mock._register_instance = CrossSync._Sync_Impl.Mock() + client_mock._remove_instance_registration = CrossSync._Sync_Impl.Mock() + client_mock._executor = concurrent.futures.ThreadPoolExecutor() + mock_async_iterator = MockIterator(proto_byte_stream) + iterator = None + with mock.patch.object( + CrossSync._Sync_Impl, + "retry_target_stream", + return_value=mock_async_iterator, + ): + iterator = self._make_one( + client=client_mock, + instance_id="test-instance", + app_profile_id="test_profile", + request_body={}, + attempt_timeout=10, + operation_timeout=10, + req_metadata=(), + retryable_excs=[], + ) + result = [] + for value in iterator: + result.append(tuple(value)) + assert result == [(1, 2), (3, 4), (5, 6)] + assert iterator.is_closed + client_mock._register_instance.assert_called_once() + client_mock._remove_instance_registration.assert_called_once() + assert mock_async_iterator.idx == len(proto_byte_stream) + + def test_iterator_awaits_metadata(self, proto_byte_stream): + client_mock = mock.Mock() + client_mock._register_instance = CrossSync._Sync_Impl.Mock() + client_mock._remove_instance_registration = CrossSync._Sync_Impl.Mock() + mock_async_iterator = MockIterator(proto_byte_stream) + iterator = None + with mock.patch.object( + CrossSync._Sync_Impl, + 
"retry_target_stream", + return_value=mock_async_iterator, + ): + iterator = self._make_one( + client=client_mock, + instance_id="test-instance", + app_profile_id="test_profile", + request_body={}, + attempt_timeout=10, + operation_timeout=10, + req_metadata=(), + retryable_excs=[], + ) + iterator.metadata() + assert mock_async_iterator.idx == 1 diff --git a/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py b/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py index 914a0920a..f7159fb71 100644 --- a/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py +++ b/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py @@ -47,7 +47,7 @@ (b"3", "bytes_value", "bytes_type", b"3"), (True, "bool_value", "bool_type", True), ( - datetime.datetime.fromtimestamp(timestamp), + datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc), "timestamp_value", "timestamp_type", dt_nanos_zero, diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 588890265..39db06689 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -81,7 +81,7 @@ def test_attempt_timeout_w_sleeps(self): sleep_time = 0.1 for i in range(3): found_value = next(generator) - assert abs(found_value - expected_value) < 0.001 + assert abs(found_value - expected_value) < 0.1 sleep(sleep_time) expected_value -= sleep_time From 7a6d0c5f18729683a9cc1e1a51589b84cd4a463b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 18 Dec 2024 15:31:03 -0600 Subject: [PATCH 4/5] chore(tests): sync client verification tests (#1046) --- .cross_sync/README.md | 2 + .github/workflows/conformance.yaml | 12 +- .kokoro/conformance.sh | 10 +- noxfile.py | 2 +- test_proxy/README.md | 2 +- .../client_handler_data_sync_autogen.py | 185 ++++++++++++++++++ test_proxy/run_tests.sh | 17 +- test_proxy/test_proxy.py | 5 +- tests/unit/data/test_sync_up_to_date.py | 99 ++++++++++ 9 files changed, 319 insertions(+), 15 deletions(-) create mode 100644 test_proxy/handlers/client_handler_data_sync_autogen.py create mode 100644 tests/unit/data/test_sync_up_to_date.py diff --git a/.cross_sync/README.md b/.cross_sync/README.md index 18a9aafdf..0d8a1cf8c 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -66,6 +66,8 @@ Generation can be initiated using `nox -s generate_sync` from the root of the project. This will find all classes with the `__CROSS_SYNC_OUTPUT__ = "path/to/output"` annotation, and generate a sync version of classes marked with `@CrossSync.convert_sync` at the output path. +There is a unit test at `tests/unit/data/test_sync_up_to_date.py` that verifies that the generated code is up to date + ## Architecture CrossSync is made up of two parts: diff --git a/.github/workflows/conformance.yaml b/.github/workflows/conformance.yaml index 448e1cc3a..8445240c3 100644 --- a/.github/workflows/conformance.yaml +++ b/.github/workflows/conformance.yaml @@ -26,7 +26,15 @@ jobs: matrix: test-version: [ "v0.0.2" ] py-version: [ 3.8 ] - client-type: [ "async", "legacy" ] + client-type: [ "async", "sync", "legacy" ] + include: + - client-type: "sync" + # sync client does not support concurrent streams + test_args: "-skip _Generic_MultiStream" + - client-type: "legacy" + # legacy client is synchronous and does not support concurrent streams + # legacy client does not expose mutate_row. 
Disable those tests + test_args: "-skip _Generic_MultiStream -skip TestMutateRow_" fail-fast: false name: "${{ matrix.client-type }} client / python ${{ matrix.py-version }} / test tag ${{ matrix.test-version }}" steps: @@ -53,4 +61,6 @@ jobs: env: CLIENT_TYPE: ${{ matrix.client-type }} PYTHONUNBUFFERED: 1 + TEST_ARGS: ${{ matrix.test_args }} + PROXY_PORT: 9999 diff --git a/.kokoro/conformance.sh b/.kokoro/conformance.sh index e85fc1394..fd585142e 100644 --- a/.kokoro/conformance.sh +++ b/.kokoro/conformance.sh @@ -19,16 +19,7 @@ set -eo pipefail ## cd to the parent directory, i.e. the root of the git repo cd $(dirname $0)/.. -PROXY_ARGS="" -TEST_ARGS="" -if [[ "${CLIENT_TYPE^^}" == "LEGACY" ]]; then - echo "Using legacy client" - # legacy client does not expose mutate_row. Disable those tests - TEST_ARGS="-skip TestMutateRow_" -fi - # Build and start the proxy in a separate process -PROXY_PORT=9999 pushd test_proxy nohup python test_proxy.py --port $PROXY_PORT --client_type=$CLIENT_TYPE & proxyPID=$! @@ -42,6 +33,7 @@ function cleanup() { trap cleanup EXIT # Run the conformance test +echo "running tests with args: $TEST_ARGS" pushd cloud-bigtable-clients-test/tests eval "go test -v -proxy_addr=:$PROXY_PORT $TEST_ARGS" RETURN_CODE=$? diff --git a/noxfile.py b/noxfile.py index 8576fed85..548bfd0ec 100644 --- a/noxfile.py +++ b/noxfile.py @@ -298,7 +298,7 @@ def system_emulated(session): @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) -@nox.parametrize("client_type", ["async"]) +@nox.parametrize("client_type", ["async", "sync", "legacy"]) def conformance(session, client_type): # install dependencies constraints_path = str( diff --git a/test_proxy/README.md b/test_proxy/README.md index 266fba7cd..5c87c729a 100644 --- a/test_proxy/README.md +++ b/test_proxy/README.md @@ -31,7 +31,7 @@ python test_proxy.py --port 8080 ``` By default, the test_proxy targets the async client. You can change this by passing in the `--client_type` flag. -Valid options are `async` and `legacy`. +Valid options are `async`, `sync`, and `legacy`. ``` python test_proxy.py --client_type=legacy diff --git a/test_proxy/handlers/client_handler_data_sync_autogen.py b/test_proxy/handlers/client_handler_data_sync_autogen.py new file mode 100644 index 000000000..eabae0ffa --- /dev/null +++ b/test_proxy/handlers/client_handler_data_sync_autogen.py @@ -0,0 +1,185 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is automatically generated by CrossSync. Do not edit manually. + +""" +This module contains the client handler process for proxy_server.py. +""" +import os +from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.cloud.bigtable.data._cross_sync import CrossSync +from client_handler_data_async import error_safe + + +class TestProxyClientHandler: + """ + Implements the same methods as the grpc server, but handles the client + library side of the request. 
+ + Requests received in TestProxyGrpcServer are converted to a dictionary, + and supplied to the TestProxyClientHandler methods as kwargs. + The client response is then returned back to the TestProxyGrpcServer + """ + + def __init__( + self, + data_target=None, + project_id=None, + instance_id=None, + app_profile_id=None, + per_operation_timeout=None, + **kwargs + ): + self.closed = False + os.environ[BIGTABLE_EMULATOR] = data_target + self.client = CrossSync._Sync_Impl.DataClient(project=project_id) + self.instance_id = instance_id + self.app_profile_id = app_profile_id + self.per_operation_timeout = per_operation_timeout + + def close(self): + self.closed = True + + @error_safe + async def ReadRows(self, request, **kwargs): + table_id = request.pop("table_name").split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + result_list = table.read_rows(request, **kwargs) + serialized_response = [row._to_dict() for row in result_list] + return serialized_response + + @error_safe + async def ReadRow(self, row_key, **kwargs): + table_id = kwargs.pop("table_name").split("/")[-1] + app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + result_row = table.read_row(row_key, **kwargs) + if result_row: + return result_row._to_dict() + else: + return "None" + + @error_safe + async def MutateRow(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import Mutation + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + row_key = request["row_key"] + mutations = [Mutation._from_dict(d) for d in request["mutations"]] + table.mutate_row(row_key, mutations, **kwargs) + return "OK" + + @error_safe + async def BulkMutateRows(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import RowMutationEntry + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + entry_list = [ + RowMutationEntry._from_dict(entry) for entry in request["entries"] + ] + table.bulk_mutate_rows(entry_list, **kwargs) + return "OK" + + @error_safe + async def CheckAndMutateRow(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import Mutation, SetCell + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + row_key = request["row_key"] + true_mutations = [] + for mut_dict in request.get("true_mutations", []): + try: + true_mutations.append(Mutation._from_dict(mut_dict)) + except ValueError: + mutation = 
SetCell("", "", "", 0) + true_mutations.append(mutation) + false_mutations = [] + for mut_dict in request.get("false_mutations", []): + try: + false_mutations.append(Mutation._from_dict(mut_dict)) + except ValueError: + false_mutations.append(SetCell("", "", "", 0)) + predicate_filter = request.get("predicate_filter", None) + result = table.check_and_mutate_row( + row_key, + predicate_filter, + true_case_mutations=true_mutations, + false_case_mutations=false_mutations, + **kwargs + ) + return result + + @error_safe + async def ReadModifyWriteRow(self, request, **kwargs): + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + row_key = request["row_key"] + rules = [] + for rule_dict in request.get("rules", []): + qualifier = rule_dict["column_qualifier"] + if "append_value" in rule_dict: + new_rule = AppendValueRule( + rule_dict["family_name"], qualifier, rule_dict["append_value"] + ) + else: + new_rule = IncrementRule( + rule_dict["family_name"], qualifier, rule_dict["increment_amount"] + ) + rules.append(new_rule) + result = table.read_modify_write_row(row_key, rules, **kwargs) + if result: + return result._to_dict() + else: + return "None" + + @error_safe + async def SampleRowKeys(self, request, **kwargs): + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + result = table.sample_row_keys(**kwargs) + return result diff --git a/test_proxy/run_tests.sh b/test_proxy/run_tests.sh index c2e9c6312..b6f1291a6 100755 --- a/test_proxy/run_tests.sh +++ b/test_proxy/run_tests.sh @@ -27,7 +27,7 @@ fi SCRIPT_DIR=$(realpath $(dirname "$0")) cd $SCRIPT_DIR -export PROXY_SERVER_PORT=50055 +export PROXY_SERVER_PORT=$(shuf -i 50000-60000 -n 1) # download test suite if [ ! -d "cloud-bigtable-clients-test" ]; then @@ -43,6 +43,19 @@ function finish { } trap finish EXIT +if [[ $CLIENT_TYPE == "legacy" ]]; then + echo "Using legacy client" + # legacy client does not expose mutate_row. 
Disable those tests + TEST_ARGS="-skip TestMutateRow_" +fi + +if [[ $CLIENT_TYPE != "async" ]]; then + echo "Using legacy client" + # sync and legacy client do not support concurrent streams + TEST_ARGS="$TEST_ARGS -skip _Generic_MultiStream " +fi + # run tests pushd cloud-bigtable-clients-test/tests -go test -v -proxy_addr=:$PROXY_SERVER_PORT +echo "Running with $TEST_ARGS" +go test -v -proxy_addr=:$PROXY_SERVER_PORT $TEST_ARGS diff --git a/test_proxy/test_proxy.py b/test_proxy/test_proxy.py index 9e03f1e5c..793500768 100644 --- a/test_proxy/test_proxy.py +++ b/test_proxy/test_proxy.py @@ -114,6 +114,9 @@ def format_dict(input_obj): if client_type == "legacy": import client_handler_legacy client = client_handler_legacy.LegacyTestProxyClientHandler(**json_data) + elif client_type == "sync": + import client_handler_data_sync_autogen + client = client_handler_data_sync_autogen.TestProxyClientHandler(**json_data) else: client = client_handler_data_async.TestProxyClientHandlerAsync(**json_data) client_map[client_id] = client @@ -150,7 +153,7 @@ def client_handler_process(request_q, queue_pool, client_type="async"): p = argparse.ArgumentParser() p.add_argument("--port", dest='port', default="50055") -p.add_argument("--client_type", dest='client_type', default="async", choices=["async", "legacy"]) +p.add_argument("--client_type", dest='client_type', default="async", choices=["async", "sync", "legacy"]) if __name__ == "__main__": port = p.parse_args().port diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py new file mode 100644 index 000000000..492d35ddf --- /dev/null +++ b/tests/unit/data/test_sync_up_to_date.py @@ -0,0 +1,99 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import hashlib +import pytest +import ast +import re +from difflib import unified_diff + +# add cross_sync to path +test_dir_name = os.path.dirname(__file__) +repo_root = os.path.join(test_dir_name, "..", "..", "..") +cross_sync_path = os.path.join(repo_root, ".cross_sync") +sys.path.append(cross_sync_path) + +from generate import convert_files_in_dir, CrossSyncOutputFile # noqa: E402 + +sync_files = list(convert_files_in_dir(repo_root)) + + +def test_found_files(): + """ + Make sure sync_test is populated with some of the files we expect to see, + to ensure that later tests are actually running. 
+ """ + assert len(sync_files) > 0, "No sync files found" + assert len(sync_files) > 10, "Unexpectedly few sync files found" + # test for key files + outputs = [os.path.basename(f.output_path) for f in sync_files] + assert "client.py" in outputs + assert "execute_query_iterator.py" in outputs + assert "test_client.py" in outputs + assert "test_system_autogen.py" in outputs, "system tests not found" + assert ( + "client_handler_data_sync_autogen.py" in outputs + ), "test proxy handler not found" + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="ast.unparse is only available in 3.9+" +) +@pytest.mark.parametrize("sync_file", sync_files, ids=lambda f: f.output_path) +def test_sync_up_to_date(sync_file): + """ + Generate a fresh copy of each cross_sync file, and compare hashes with the existing file. + + If this test fails, run `nox -s generate_sync` to update the sync files. + """ + path = sync_file.output_path + new_render = sync_file.render(with_formatter=True, save_to_disk=False) + found_render = CrossSyncOutputFile( + output_path="", ast_tree=ast.parse(open(path).read()), header=sync_file.header + ).render(with_formatter=True, save_to_disk=False) + # compare by content + diff = unified_diff(found_render.splitlines(), new_render.splitlines(), lineterm="") + diff_str = "\n".join(diff) + assert ( + not diff_str + ), f"Found differences. Run `nox -s generate_sync` to update:\n{diff_str}" + # compare by hash + new_hash = hashlib.md5(new_render.encode()).hexdigest() + found_hash = hashlib.md5(found_render.encode()).hexdigest() + assert new_hash == found_hash, f"md5 mismatch for {path}" + + +@pytest.mark.parametrize("sync_file", sync_files, ids=lambda f: f.output_path) +def test_verify_headers(sync_file): + license_regex = r""" + \#\ Copyright\ \d{4}\ Google\ LLC\n + \#\n + \#\ Licensed\ under\ the\ Apache\ License,\ Version\ 2\.0\ \(the\ \"License\"\);\n + \#\ you\ may\ not\ use\ this\ file\ except\ in\ compliance\ with\ the\ License\.\n + \#\ You\ may\ obtain\ a\ copy\ of\ the\ License\ at\ + \#\n + \#\s+http:\/\/www\.apache\.org\/licenses\/LICENSE-2\.0\n + \#\n + \#\ Unless\ required\ by\ applicable\ law\ or\ agreed\ to\ in\ writing,\ software\n + \#\ distributed\ under\ the\ License\ is\ distributed\ on\ an\ \"AS\ IS\"\ BASIS,\n + \#\ WITHOUT\ WARRANTIES\ OR\ CONDITIONS\ OF\ ANY\ KIND,\ either\ express\ or\ implied\.\n + \#\ See\ the\ License\ for\ the\ specific\ language\ governing\ permissions\ and\n + \#\ limitations\ under\ the\ License\. 
+ """ + pattern = re.compile(license_regex, re.VERBOSE) + + with open(sync_file.output_path, "r") as f: + content = f.read() + assert pattern.search(content), "Missing license header" From 9c6f79d52e5bb4644cb4c483dd705b3cc0b41856 Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:42:31 -0800 Subject: [PATCH 5/5] chore(python): update dependencies in .kokoro/docker/docs (#1052) Source-Link: https://github.com/googleapis/synthtool/commit/e808c98e1ab7eec3df2a95a05331619f7001daef Post-Processor: gcr.io/cloud-devrel-public-resources/owlbot-python:latest@sha256:8e3e7e18255c22d1489258d0374c901c01f9c4fd77a12088670cd73d580aa737 Co-authored-by: Owl Bot --- .github/.OwlBot.lock.yaml | 4 +- .github/release-trigger.yml | 2 +- .kokoro/docker/docs/requirements.txt | 66 ++++++++++++++++++++-------- 3 files changed, 51 insertions(+), 21 deletions(-) diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml index 2fda9335f..26306af66 100644 --- a/.github/.OwlBot.lock.yaml +++ b/.github/.OwlBot.lock.yaml @@ -13,5 +13,5 @@ # limitations under the License. docker: image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest - digest: sha256:5cddfe2fb5019bbf78335bc55f15bc13e18354a56b3ff46e1834f8e540807f05 -# created: 2024-10-31T01:41:07.349286254Z \ No newline at end of file + digest: sha256:8e3e7e18255c22d1489258d0374c901c01f9c4fd77a12088670cd73d580aa737 +# created: 2024-12-17T00:59:58.625514486Z diff --git a/.github/release-trigger.yml b/.github/release-trigger.yml index 4bb79e58e..0bbdd8e4c 100644 --- a/.github/release-trigger.yml +++ b/.github/release-trigger.yml @@ -1,2 +1,2 @@ enabled: true -multiScmName: +multiScmName: python-bigtable diff --git a/.kokoro/docker/docs/requirements.txt b/.kokoro/docker/docs/requirements.txt index 66eacc82f..f99a5c4aa 100644 --- a/.kokoro/docker/docs/requirements.txt +++ b/.kokoro/docker/docs/requirements.txt @@ -1,16 +1,16 @@ # -# This file is autogenerated by pip-compile with Python 3.9 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# pip-compile --allow-unsafe --generate-hashes requirements.in +# pip-compile --allow-unsafe --generate-hashes synthtool/gcp/templates/python_library/.kokoro/docker/docs/requirements.in # -argcomplete==3.5.1 \ - --hash=sha256:1a1d148bdaa3e3b93454900163403df41448a248af01b6e849edc5ac08e6c363 \ - --hash=sha256:eb1ee355aa2557bd3d0145de7b06b2a45b0ce461e1e7813f5d066039ab4177b4 +argcomplete==3.5.2 \ + --hash=sha256:036d020d79048a5d525bc63880d7a4b8d1668566b8a76daf1144c0bbe0f63472 \ + --hash=sha256:23146ed7ac4403b70bd6026402468942ceba34a6732255b9edf5b7354f68a6bb # via nox -colorlog==6.8.2 \ - --hash=sha256:3e3e079a41feb5a1b64f978b5ea4f46040a94f11f0e8bbb8261e3dbbeca64d44 \ - --hash=sha256:4dcbb62368e2800cb3c5abd348da7e53f6c362dda502ec27c560b2e58a66bd33 +colorlog==6.9.0 \ + --hash=sha256:5906e71acd67cb07a71e779c47c4bcb45fb8c2993eebe9e5adcd6a6f1b283eff \ + --hash=sha256:bfba54a1b93b94f54e1f4fe48395725a3d92fd2a4af702f6bd70946bdc0c6ac2 # via nox distlib==0.3.9 \ --hash=sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87 \ @@ -23,20 +23,50 @@ filelock==3.16.1 \ nox==2024.10.9 \ --hash=sha256:1d36f309a0a2a853e9bccb76bbef6bb118ba92fa92674d15604ca99adeb29eab \ --hash=sha256:7aa9dc8d1c27e9f45ab046ffd1c3b2c4f7c91755304769df231308849ebded95 - # via -r requirements.in -packaging==24.1 \ - --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ - 
--hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 + # via -r synthtool/gcp/templates/python_library/.kokoro/docker/docs/requirements.in +packaging==24.2 \ + --hash=sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759 \ + --hash=sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f # via nox platformdirs==4.3.6 \ --hash=sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907 \ --hash=sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb # via virtualenv -tomli==2.0.2 \ - --hash=sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38 \ - --hash=sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed +tomli==2.2.1 \ + --hash=sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6 \ + --hash=sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd \ + --hash=sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c \ + --hash=sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b \ + --hash=sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8 \ + --hash=sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6 \ + --hash=sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77 \ + --hash=sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff \ + --hash=sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea \ + --hash=sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192 \ + --hash=sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249 \ + --hash=sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee \ + --hash=sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4 \ + --hash=sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98 \ + --hash=sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8 \ + --hash=sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4 \ + --hash=sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281 \ + --hash=sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744 \ + --hash=sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69 \ + --hash=sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13 \ + --hash=sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140 \ + --hash=sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e \ + --hash=sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e \ + --hash=sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc \ + --hash=sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff \ + --hash=sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec \ + --hash=sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2 \ + --hash=sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222 \ + --hash=sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106 \ + --hash=sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272 \ + --hash=sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a \ + --hash=sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7 # via nox -virtualenv==20.26.6 \ - --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \ - 
--hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2 +virtualenv==20.28.0 \ + --hash=sha256:23eae1b4516ecd610481eda647f3a7c09aea295055337331bb4e6892ecce47b0 \ + --hash=sha256:2c9c3262bb8e7b87ea801d715fae4495e6032450c71d2309be9550e7364049aa # via nox