From 45bc8c4a0fe567ce5e0126a1a70e7eb3dca93e92 Mon Sep 17 00:00:00 2001 From: Kajetan Boroszko Date: Thu, 8 Aug 2024 22:12:25 +0200 Subject: [PATCH] feat: async execute query client (#1011) Co-authored-by: Mateusz Walkiewicz Co-authored-by: Owl Bot --- google/cloud/bigtable/data/__init__.py | 2 + .../bigtable/data/_async/_mutate_rows.py | 4 +- .../cloud/bigtable/data/_async/_read_rows.py | 3 +- google/cloud/bigtable/data/_async/client.py | 234 ++++-- google/cloud/bigtable/data/_helpers.py | 38 +- google/cloud/bigtable/data/exceptions.py | 8 + .../bigtable/data/execute_query/__init__.py | 38 + .../data/execute_query/_async/__init__.py | 13 + .../_async/execute_query_iterator.py | 211 ++++++ .../data/execute_query/_byte_cursor.py | 144 ++++ .../execute_query/_parameters_formatting.py | 118 +++ .../_query_result_parsing_utils.py | 133 ++++ .../bigtable/data/execute_query/_reader.py | 149 ++++ .../bigtable/data/execute_query/metadata.py | 354 +++++++++ .../bigtable/data/execute_query/values.py | 116 +++ google/cloud/bigtable/helpers.py | 31 + google/cloud/bigtable/instance.py | 1 + .../data_client/data_client_snippets_async.py | 45 +- .../data_client_snippets_async_test.py | 5 + testing/constraints-3.8.txt | 1 + tests/_testing.py | 36 + tests/system/data/test_execute_query_async.py | 288 +++++++ tests/system/data/test_execute_query_utils.py | 272 +++++++ tests/unit/_testing.py | 16 + tests/unit/data/_async/__init__.py | 13 + tests/unit/data/_testing.py | 18 + tests/unit/data/execute_query/__init__.py | 13 + .../data/execute_query/_async/__init__.py | 13 + .../data/execute_query/_async/_testing.py | 36 + .../_async/test_query_iterator.py | 156 ++++ tests/unit/data/execute_query/_testing.py | 17 + .../data/execute_query/test_byte_cursor.py | 149 ++++ .../test_execute_query_parameters_parsing.py | 134 ++++ .../test_query_result_parsing_utils.py | 715 ++++++++++++++++++ .../test_query_result_row_reader.py | 310 ++++++++ tests/unit/data/test__helpers.py | 25 +- tests/unit/data/test_helpers.py | 45 ++ tests/unit/v2_client/_testing.py | 3 + tests/unit/v2_client/test_instance.py | 26 + 39 files changed, 3855 insertions(+), 78 deletions(-) create mode 100644 google/cloud/bigtable/data/execute_query/__init__.py create mode 100644 google/cloud/bigtable/data/execute_query/_async/__init__.py create mode 100644 google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py create mode 100644 google/cloud/bigtable/data/execute_query/_byte_cursor.py create mode 100644 google/cloud/bigtable/data/execute_query/_parameters_formatting.py create mode 100644 google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py create mode 100644 google/cloud/bigtable/data/execute_query/_reader.py create mode 100644 google/cloud/bigtable/data/execute_query/metadata.py create mode 100644 google/cloud/bigtable/data/execute_query/values.py create mode 100644 google/cloud/bigtable/helpers.py create mode 100644 tests/_testing.py create mode 100644 tests/system/data/test_execute_query_async.py create mode 100644 tests/system/data/test_execute_query_utils.py create mode 100644 tests/unit/_testing.py create mode 100644 tests/unit/data/_async/__init__.py create mode 100644 tests/unit/data/_testing.py create mode 100644 tests/unit/data/execute_query/__init__.py create mode 100644 tests/unit/data/execute_query/_async/__init__.py create mode 100644 tests/unit/data/execute_query/_async/_testing.py create mode 100644 tests/unit/data/execute_query/_async/test_query_iterator.py create mode 100644 
tests/unit/data/execute_query/_testing.py create mode 100644 tests/unit/data/execute_query/test_byte_cursor.py create mode 100644 tests/unit/data/execute_query/test_execute_query_parameters_parsing.py create mode 100644 tests/unit/data/execute_query/test_query_result_parsing_utils.py create mode 100644 tests/unit/data/execute_query/test_query_result_row_reader.py create mode 100644 tests/unit/data/test_helpers.py diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 5229f8021..68dc22891 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -39,6 +39,7 @@ from google.cloud.bigtable.data.exceptions import RetryExceptionGroup from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup +from google.cloud.bigtable.data.exceptions import ParameterTypeInferenceFailed from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import RowKeySamples @@ -68,6 +69,7 @@ "RetryExceptionGroup", "MutationsExceptionGroup", "ShardedReadRowsExceptionGroup", + "ParameterTypeInferenceFailed", "ShardedQuery", "TABLE_DEFAULT", ) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 99b9944cd..465378aa4 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -84,7 +84,9 @@ def __init__( f"all entries. Found {total_mutations}." ) # create partial function to pass to trigger rpc call - metadata = _make_metadata(table.table_name, table.app_profile_id) + metadata = _make_metadata( + table.table_name, table.app_profile_id, instance_name=None + ) self._gapic_fn = functools.partial( gapic_client.mutate_rows, table_name=table.table_name, diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 78cb7a991..6034ae6cf 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -102,8 +102,7 @@ def __init__( self.table = table self._predicate = retries.if_exception_type(*retryable_exceptions) self._metadata = _make_metadata( - table.table_name, - table.app_profile_id, + table.table_name, table.app_profile_id, instance_name=None ) self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 34fdf847a..600937df8 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -15,74 +15,88 @@ from __future__ import annotations +import asyncio +from functools import partial +import os +import random +import sys +import time from typing import ( - cast, + TYPE_CHECKING, Any, AsyncIterable, + Dict, Optional, - Set, Sequence, - TYPE_CHECKING, + Set, + Union, + cast, ) - -import asyncio -import grpc -import time import warnings -import sys -import random -import os -from functools import partial +from google.api_core import client_options as client_options_lib +from google.api_core import retry as retries +from google.api_core.exceptions import Aborted, DeadlineExceeded, ServiceUnavailable +import google.auth._default +import google.auth.credentials +from google.cloud.client import ClientWithProject +from google.cloud.environment_vars import 
BIGTABLE_EMULATOR # type: ignore +import grpc +from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT +from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( + ExecuteQueryIteratorAsync, +) +from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +from google.cloud.bigtable.data._async.mutations_batcher import ( + _MB_SIZE, + MutationsBatcherAsync, +) +from google.cloud.bigtable.data._helpers import ( + _CONCURRENCY_LIMIT, + TABLE_DEFAULT, + _attempt_timeout_generator, + _get_error_type, + _get_retryable_errors, + _get_timeouts, + _make_metadata, + _retry_exception_factory, + _validate_timeouts, + _WarmedInstanceKey, +) +from google.cloud.bigtable.data.exceptions import ( + FailedQueryShardError, + ShardedReadRowsExceptionGroup, +) +from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry +from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable.data.row_filters import ( + CellsRowLimitFilter, + RowFilter, + RowFilterChain, + StripValueTransformerFilter, +) +from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType +from google.cloud.bigtable.data.execute_query.metadata import SqlType +from google.cloud.bigtable.data.execute_query._parameters_formatting import ( + _format_execute_query_params, +) +from google.cloud.bigtable_v2.services.bigtable.async_client import ( + DEFAULT_CLIENT_INFO, + BigtableAsyncClient, +) from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient -from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledBigtableGrpcAsyncIOTransport, PooledChannel, ) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest -from google.cloud.client import ClientWithProject -from google.cloud.environment_vars import BIGTABLE_EMULATOR # type: ignore -from google.api_core import retry as retries -from google.api_core.exceptions import DeadlineExceeded -from google.api_core.exceptions import ServiceUnavailable -from google.api_core.exceptions import Aborted -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - -import google.auth.credentials -import google.auth._default -from google.api_core import client_options as client_options_lib -from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT -from google.cloud.bigtable.data.row import Row -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.cloud.bigtable.data.exceptions import FailedQueryShardError -from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup - -from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry -from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._helpers import _WarmedInstanceKey -from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT -from google.cloud.bigtable.data._helpers import _make_metadata -from google.cloud.bigtable.data._helpers 
import _retry_exception_factory -from google.cloud.bigtable.data._helpers import _validate_timeouts -from google.cloud.bigtable.data._helpers import _get_retryable_errors -from google.cloud.bigtable.data._helpers import _get_timeouts -from google.cloud.bigtable.data._helpers import _attempt_timeout_generator -from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync -from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE -from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule -from google.cloud.bigtable.data.row_filters import RowFilter -from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter -from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter -from google.cloud.bigtable.data.row_filters import RowFilterChain - if TYPE_CHECKING: - from google.cloud.bigtable.data._helpers import RowKeySamples - from google.cloud.bigtable.data._helpers import ShardedQuery + from google.cloud.bigtable.data._helpers import RowKeySamples, ShardedQuery class BigtableDataClientAsync(ClientWithProject): @@ -315,7 +329,9 @@ async def _manage_channel( next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.time() - start_timestamp) - async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: + async def _register_instance( + self, instance_id: str, owner: Union[TableAsync, ExecuteQueryIteratorAsync] + ) -> None: """ Registers an instance with the client, and warms the channel pool for the instance @@ -346,7 +362,7 @@ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: self._start_background_channel_refresh() async def _remove_instance_registration( - self, instance_id: str, owner: TableAsync + self, instance_id: str, owner: Union[TableAsync, ExecuteQueryIteratorAsync] ) -> bool: """ Removes an instance from the client's registered instances, to prevent @@ -416,6 +432,102 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs """ return TableAsync(self, instance_id, table_id, *args, **kwargs) + async def execute_query( + self, + query: str, + instance_id: str, + *, + parameters: Dict[str, ExecuteQueryValueType] | None = None, + parameter_types: Dict[str, SqlType.Type] | None = None, + app_profile_id: str | None = None, + operation_timeout: float = 600, + attempt_timeout: float | None = 20, + retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + Aborted, + ), + ) -> "ExecuteQueryIteratorAsync": + """ + Executes an SQL query on an instance. + Returns an iterator to asynchronously stream back columns from selected rows. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Args: + - query: Query to be run on Bigtable instance. The query can use ``@param`` + placeholders to use parameter interpolation on the server. Values for all + parameters should be provided in ``parameters``. Types of parameters are + inferred but should be provided in ``parameter_types`` if the inference is + not possible (i.e. when value can be None, an empty list or an empty dict). + - instance_id: The Bigtable instance ID to perform the query on. + instance_id is combined with the client's project to fully + specify the instance. + - parameters: Dictionary with values for all parameters used in the ``query``. 
+ - parameter_types: Dictionary with types of parameters used in the ``query``. + Required to contain entries only for parameters whose type cannot be + detected automatically (i.e. the value can be None, an empty list or + an empty dict). + - app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to 600 seconds. + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the 20 seconds. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + Returns: + - an asynchronous iterator that yields rows returned by the query + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + warnings.warn( + "ExecuteQuery is in preview and may change in the future.", + category=RuntimeWarning, + ) + + retryable_excs = [_get_error_type(e) for e in retryable_errors] + + pb_params = _format_execute_query_params(parameters, parameter_types) + + instance_name = self._gapic_client.instance_path(self.project, instance_id) + + request_body = { + "instance_name": instance_name, + "app_profile_id": app_profile_id, + "query": query, + "params": pb_params, + "proto_format": {}, + } + + # app_profile_id should be set to an empty string for ExecuteQueryRequest only + app_profile_id_for_metadata = app_profile_id or "" + + req_metadata = _make_metadata( + table_name=None, + app_profile_id=app_profile_id_for_metadata, + instance_name=instance_name, + ) + + return ExecuteQueryIteratorAsync( + self, + instance_id, + app_profile_id, + request_body, + attempt_timeout, + operation_timeout, + req_metadata, + retryable_excs, + ) + async def __aenter__(self): self._start_background_channel_refresh() return self @@ -893,7 +1005,9 @@ async def sample_row_keys( sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) # prepare request - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _make_metadata( + self.table_name, self.app_profile_id, instance_name=None + ) async def execute_rpc(): results = await self.client._gapic_client.sample_row_keys( @@ -1029,7 +1143,9 @@ async def mutate_row( table_name=self.table_name, app_profile_id=self.app_profile_id, timeout=attempt_timeout, - metadata=_make_metadata(self.table_name, self.app_profile_id), + metadata=_make_metadata( + self.table_name, self.app_profile_id, instance_name=None + ), retry=None, ) return await retries.retry_target_async( @@ -1147,7 +1263,9 @@ async def check_and_mutate_row( ): false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _make_metadata( + self.table_name, self.app_profile_id, instance_name=None + ) result = await self.client._gapic_client.check_and_mutate_row( true_mutations=true_case_list, false_mutations=false_case_list, @@ -1198,7 +1316,9 @@ async def read_modify_write_row( 
rules = [rules] if not rules: raise ValueError("rules must contain at least one item") - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _make_metadata( + self.table_name, self.app_profile_id, instance_name=None + ) result = await self.client._gapic_client.read_modify_write_row( rules=[rule._to_pb() for rule in rules], row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index a8fba9ef1..2d36c521f 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -16,7 +16,7 @@ """ from __future__ import annotations -from typing import Sequence, List, Tuple, TYPE_CHECKING +from typing import Sequence, List, Tuple, TYPE_CHECKING, Union import time import enum from collections import namedtuple @@ -60,15 +60,26 @@ class TABLE_DEFAULT(enum.Enum): def _make_metadata( - table_name: str, app_profile_id: str | None + table_name: str | None, app_profile_id: str | None, instance_name: str | None ) -> list[tuple[str, str]]: """ Create properly formatted gRPC metadata for requests. """ params = [] - params.append(f"table_name={table_name}") + + if table_name is not None and instance_name is not None: + raise ValueError("metadata can't contain both instance_name and table_name") + + if table_name is not None: + params.append(f"table_name={table_name}") + if instance_name is not None: + params.append(f"name={instance_name}") if app_profile_id is not None: params.append(f"app_profile_id={app_profile_id}") + if len(params) == 0: + raise ValueError( + "At least one of table_name and app_profile_id should be not None." + ) params_str = "&".join(params) return [("x-goog-request-params", params_str)] @@ -203,6 +214,22 @@ def _validate_timeouts( raise ValueError("attempt_timeout must be greater than 0") +def _get_error_type( + call_code: Union["grpc.StatusCode", int, type[Exception]] +) -> type[Exception]: + """Helper function for ensuring the object is an exception type. + If it is not, the proper GoogleAPICallError type is infered from the status + code. + + Args: + - call_code: Exception type or gRPC status code. 
+ """ + if isinstance(call_code, type): + return call_code + else: + return type(core_exceptions.from_grpc_status(call_code, "")) + + def _get_retryable_errors( call_codes: Sequence["grpc.StatusCode" | int | type[Exception]] | TABLE_DEFAULT, table: "TableAsync", @@ -225,7 +252,4 @@ def _get_retryable_errors( elif call_codes == TABLE_DEFAULT.MUTATE_ROWS: call_codes = table.default_mutate_rows_retryable_errors - return [ - e if isinstance(e, type) else type(core_exceptions.from_grpc_status(e, "")) - for e in call_codes - ] + return [_get_error_type(e) for e in call_codes] diff --git a/google/cloud/bigtable/data/exceptions.py b/google/cloud/bigtable/data/exceptions.py index 8d97640aa..95cd44f2c 100644 --- a/google/cloud/bigtable/data/exceptions.py +++ b/google/cloud/bigtable/data/exceptions.py @@ -311,3 +311,11 @@ def __init__( self.__cause__ = cause self.index = failed_index self.query = failed_query + + +class InvalidExecuteQueryResponse(core_exceptions.GoogleAPICallError): + """Exception raised to invalid query response data from back-end.""" + + +class ParameterTypeInferenceFailed(ValueError): + """Exception raised when query parameter types were not provided and cannot be inferred.""" diff --git a/google/cloud/bigtable/data/execute_query/__init__.py b/google/cloud/bigtable/data/execute_query/__init__.py new file mode 100644 index 000000000..94af7d1cd --- /dev/null +++ b/google/cloud/bigtable/data/execute_query/__init__.py @@ -0,0 +1,38 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( + ExecuteQueryIteratorAsync, +) +from google.cloud.bigtable.data.execute_query.metadata import ( + Metadata, + ProtoMetadata, + SqlType, +) +from google.cloud.bigtable.data.execute_query.values import ( + ExecuteQueryValueType, + QueryResultRow, + Struct, +) + + +__all__ = [ + "ExecuteQueryValueType", + "SqlType", + "QueryResultRow", + "Struct", + "Metadata", + "ProtoMetadata", + "ExecuteQueryIteratorAsync", +] diff --git a/google/cloud/bigtable/data/execute_query/_async/__init__.py b/google/cloud/bigtable/data/execute_query/_async/__init__.py new file mode 100644 index 000000000..6d5e14bcf --- /dev/null +++ b/google/cloud/bigtable/data/execute_query/_async/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
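For context, here is a minimal usage sketch of the execute_query surface introduced in this change. It is not part of the patch: the project, instance, table, column and parameter names are hypothetical, and the query text is only illustrative; the call shape (awaiting execute_query, iterating with "async for", passing parameters and optional parameter_types) follows the method and iterator added by this change, and field access by column name relies on QueryResultRow's by-name lookup.

import asyncio

from google.cloud.bigtable.data import BigtableDataClientAsync
from google.cloud.bigtable.data.execute_query import SqlType


async def main():
    # Hypothetical project/instance/table names; ExecuteQuery is in preview.
    async with BigtableDataClientAsync(project="my-project") as client:
        iterator = await client.execute_query(
            "SELECT _key, cf['greeting'] AS greeting FROM my_table WHERE _key = @user_key",
            instance_id="my-instance",
            parameters={"user_key": b"user-1"},
            # parameter_types is only required when a type cannot be inferred
            # (e.g. a None value); shown here purely for illustration.
            parameter_types={"user_key": SqlType.Bytes()},
        )
        async for row in iterator:
            print(row["greeting"])


asyncio.run(main())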
diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py new file mode 100644 index 000000000..3660c0b0f --- /dev/null +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -0,0 +1,211 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from typing import ( + Any, + AsyncIterator, + Dict, + List, + Optional, + Sequence, + Tuple, +) + +from google.api_core import retry as retries + +from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor +from google.cloud.bigtable.data._helpers import ( + _attempt_timeout_generator, + _retry_exception_factory, +) +from google.cloud.bigtable.data.exceptions import InvalidExecuteQueryResponse +from google.cloud.bigtable.data.execute_query.values import QueryResultRow +from google.cloud.bigtable.data.execute_query.metadata import Metadata, ProtoMetadata +from google.cloud.bigtable.data.execute_query._reader import ( + _QueryResultRowReader, + _Reader, +) +from google.cloud.bigtable_v2.types.bigtable import ( + ExecuteQueryRequest as ExecuteQueryRequestPB, +) + + +class ExecuteQueryIteratorAsync: + """ + ExecuteQueryIteratorAsync handles collecting streaming responses from the + ExecuteQuery RPC and parsing them to `QueryResultRow`s. + + ExecuteQueryIteratorAsync implements the asynchronous iterator interface and can + be used with "async for" syntax. + + It is **not thread-safe**. It should not be used by multiple asyncio Tasks. + + Args: + client (google.cloud.bigtable.data._async.BigtableDataClientAsync): bigtable client + instance_id (str): id of the instance on which the query is executed + app_profile_id (Optional[str]): app profile id to associate with the requests + request_body (Dict[str, Any]): dict representing the body of the ExecuteQueryRequest + attempt_timeout (float | None): the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to 20 seconds. If None, defaults to operation_timeout. + operation_timeout (float): the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to 600 seconds. + req_metadata (Sequence[Tuple[str, str]]): metadata used while sending the gRPC request + retryable_excs (List[type[Exception]]): a list of errors that will be retried if encountered.
+ """ + + def __init__( + self, + client: Any, + instance_id: str, + app_profile_id: Optional[str], + request_body: Dict[str, Any], + attempt_timeout: float | None, + operation_timeout: float, + req_metadata: Sequence[Tuple[str, str]], + retryable_excs: List[type[Exception]], + ) -> None: + self._table_name = None + self._app_profile_id = app_profile_id + self._client = client + self._instance_id = instance_id + self._byte_cursor = _ByteCursor[ProtoMetadata]() + self._reader: _Reader[QueryResultRow] = _QueryResultRowReader(self._byte_cursor) + self._result_generator = self._next_impl() + self._register_instance_task = None + self._is_closed = False + self._request_body = request_body + self._attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self._async_stream = retries.retry_target_stream_async( + self._make_request_with_resume_token, + retries.if_exception_type(*retryable_excs), + retries.exponential_sleep_generator(0.01, 60, multiplier=2), + operation_timeout, + exception_factory=_retry_exception_factory, + ) + self._req_metadata = req_metadata + + try: + self._register_instance_task = asyncio.create_task( + self._client._register_instance(instance_id, self) + ) + except RuntimeError as e: + raise RuntimeError( + f"{self.__class__.__name__} must be created within an async event loop context." + ) from e + + @property + def is_closed(self): + return self._is_closed + + @property + def app_profile_id(self): + return self._app_profile_id + + @property + def table_name(self): + return self._table_name + + async def _make_request_with_resume_token(self): + """ + Performs the RPC call using the correct resume token. + """ + resume_token = self._byte_cursor.prepare_for_new_request() + request = ExecuteQueryRequestPB( + { + **self._request_body, + "resume_token": resume_token, + } + ) + return await self._client._gapic_client.execute_query( + request, + timeout=next(self._attempt_timeout_gen), + metadata=self._req_metadata, + retry=None, + ) + + async def _await_metadata(self) -> None: + """ + If called before the first response was received, the first response + is awaited as part of this call. + """ + if self._byte_cursor.metadata is None: + metadata_msg = await self._async_stream.__anext__() + self._byte_cursor.consume_metadata(metadata_msg) + + async def _next_impl(self) -> AsyncIterator[QueryResultRow]: + """ + Generator wrapping the response stream which parses the stream results + and returns full `QueryResultRow`s. + """ + await self._await_metadata() + + async for response in self._async_stream: + try: + bytes_to_parse = self._byte_cursor.consume(response) + if bytes_to_parse is None: + continue + + results = self._reader.consume(bytes_to_parse) + if results is None: + continue + + except ValueError as e: + raise InvalidExecuteQueryResponse( + "Invalid ExecuteQuery response received" + ) from e + + for result in results: + yield result + await self.close() + + async def __anext__(self): + if self._is_closed: + raise StopAsyncIteration + return await self._result_generator.__anext__() + + def __aiter__(self): + return self + + async def metadata(self) -> Optional[Metadata]: + """ + Returns query metadata from the server or None if the iterator was + explicitly closed. + """ + if self._is_closed: + return None + # Metadata should be present in the first response in a stream.
+ if self._byte_cursor.metadata is None: + try: + await self._await_metadata() + except StopAsyncIteration: + return None + return self._byte_cursor.metadata + + async def close(self) -> None: + """ + Cancel all background tasks. Should be called after all rows have been processed. + """ + if self._is_closed: + return + self._is_closed = True + if self._register_instance_task is not None: + self._register_instance_task.cancel() + await self._client._remove_instance_registration(self._instance_id, self) diff --git a/google/cloud/bigtable/data/execute_query/_byte_cursor.py b/google/cloud/bigtable/data/execute_query/_byte_cursor.py new file mode 100644 index 000000000..60f23f541 --- /dev/null +++ b/google/cloud/bigtable/data/execute_query/_byte_cursor.py @@ -0,0 +1,144 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Generic, Optional, TypeVar + +from google.cloud.bigtable_v2 import ExecuteQueryResponse +from google.cloud.bigtable.data.execute_query.metadata import ( + Metadata, + _pb_metadata_to_metadata_types, +) + +MT = TypeVar("MT", bound=Metadata) # metadata type + + +class _ByteCursor(Generic[MT]): + """ + Buffers bytes from `ExecuteQuery` responses until a resume_token is received or end-of-stream + is reached. :class:`google.cloud.bigtable_v2.types.bigtable.ExecuteQueryResponse` obtained from + the server should be passed to the ``consume`` or ``consume_metadata`` methods and their non-None + results should be passed to an appropriate + :class:`google.cloud.bigtable.execute_query_reader._Reader` for parsing the gathered bytes. + + This class consumes data obtained externally to be usable in both sync and async clients. + + See :class:`google.cloud.bigtable.execute_query_reader._Reader` for more context. + """ + + def __init__(self): + self._metadata: Optional[MT] = None + self._buffer = bytearray() + self._resume_token = None + self._last_response_results_field = None + + @property + def metadata(self) -> Optional[MT]: + """ + Returns: + Metadata or None: Metadata read from the first response of the stream + or None if no response was consumed yet. + """ + return self._metadata + + def prepare_for_new_request(self): + """ + Prepares this ``_ByteCursor`` for retrying an ``ExecuteQuery`` request. + + Clears the internal buffers of this ``_ByteCursor`` and returns the last received + ``resume_token`` to be used in the retried request. + + This is the only method that returns the ``resume_token`` to the user. + Returning the token to the user is tightly coupled with clearing internal + buffers to prevent an accidental retry without clearing the state, which would + cause invalid results. The ``resume_token`` is not needed in other cases, + thus there is no separate getter for it. + + Returns: + bytes: Last received resume_token. + """ + self._buffer = bytearray() + # metadata is sent in the first response in a stream, + # if we've already received one, but it was not already committed + # by a subsequent resume_token, then we should clear it as well.
+ if not self._resume_token: + self._metadata = None + + return self._resume_token + + def consume_metadata(self, response: ExecuteQueryResponse) -> None: + """ + Reads metadata from first response of ``ExecuteQuery`` responses stream. + Should be called only once. + + Args: + response (google.cloud.bigtable_v2.types.bigtable.ExecuteQueryResponse): First response + from the stream. + + Raises: + ValueError: If this method was already called or if metadata received from the server + cannot be parsed. + """ + if self._metadata is not None: + raise ValueError("Invalid state - metadata already consumed") + + if "metadata" in response: + metadata: Any = _pb_metadata_to_metadata_types(response.metadata) + self._metadata = metadata + else: + raise ValueError("Invalid parameter - response without metadata") + + return None + + def consume(self, response: ExecuteQueryResponse) -> Optional[bytes]: + """ + Reads results bytes from an ``ExecuteQuery`` response and adds them to a buffer. + + If the response contains a ``resume_token``: + - the ``resume_token`` is saved in this ``_ByteCursor``, and + - internal buffers are flushed and returned to the caller. + + ``resume_token`` is not available directly, but can be retrieved by calling + :meth:`._ByteCursor.prepare_for_new_request` when preparing to retry a request. + + Args: + response (google.cloud.bigtable_v2.types.bigtable.ExecuteQueryResponse): + Response obtained from the stream. + + Returns: + bytes or None: bytes if buffers were flushed or None otherwise. + + Raises: + ValueError: If provided ``ExecuteQueryResponse`` is not valid + or contains bytes representing response of a different kind than previously + processed responses. + """ + response_pb = response._pb # proto-plus attribute retrieval is slow. + + if response_pb.HasField("results"): + results = response_pb.results + if results.HasField("proto_rows_batch"): + self._buffer.extend(results.proto_rows_batch.batch_data) + + if results.resume_token: + self._resume_token = results.resume_token + + if self._buffer: + return_value = memoryview(self._buffer) + self._buffer = bytearray() + return return_value + elif response_pb.HasField("metadata"): + self.consume_metadata(response) + else: + raise ValueError(f"Invalid ExecuteQueryResponse: {response}") + return None diff --git a/google/cloud/bigtable/data/execute_query/_parameters_formatting.py b/google/cloud/bigtable/data/execute_query/_parameters_formatting.py new file mode 100644 index 000000000..edb7a6380 --- /dev/null +++ b/google/cloud/bigtable/data/execute_query/_parameters_formatting.py @@ -0,0 +1,118 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
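As an aside, the buffering contract documented above can be exercised directly. The sketch below is not part of the patch: the payload bytes and resume token are made up, and the nested-dict construction assumes the bigtable_v2 ExecuteQueryResponse/PartialResultSet field names used elsewhere in this change.

from google.cloud.bigtable_v2 import ExecuteQueryResponse
from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor

cursor = _ByteCursor()

# Bytes that arrive without a resume_token are only buffered; consume() returns None.
first = ExecuteQueryResponse(
    results={"proto_rows_batch": {"batch_data": b"chunk-1"}}
)
assert cursor.consume(first) is None

# Once a resume_token arrives, everything buffered so far is flushed to the caller.
second = ExecuteQueryResponse(
    results={
        "proto_rows_batch": {"batch_data": b"chunk-2"},
        "resume_token": b"token-1",
    }
)
assert bytes(cursor.consume(second)) == b"chunk-1chunk-2"

# On retry, the last committed token is handed back and the buffer is reset.
assert cursor.prepare_for_new_request() == b"token-1"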
+ +from typing import Any, Dict, Optional +import datetime +from google.api_core.datetime_helpers import DatetimeWithNanoseconds +from google.cloud.bigtable.data.exceptions import ParameterTypeInferenceFailed +from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType +from google.cloud.bigtable.data.execute_query.metadata import SqlType + + +def _format_execute_query_params( + params: Optional[Dict[str, ExecuteQueryValueType]], + parameter_types: Optional[Dict[str, SqlType.Type]], +) -> Any: + """ + Takes a dictionary of param_name -> param_value and optionally parameter types. + If the parameter types are not provided, this function tries to infer them. + + Args: + params (Optional[Dict[str, ExecuteQueryValueType]]): mapping from parameter names + as they appear in the query (without the @ prefix) to their values. + Only values of type ExecuteQueryValueType are permitted. + parameter_types (Optional[Dict[str, SqlType.Type]]): mapping of parameter names + to their types. + + Raises: + ValueError: raised when parameter types cannot be inferred and were not + provided explicitly. + + Returns: + dictionary parsable to a protobuf representing parameters as defined + in ExecuteQueryRequest.params + """ + if not params: + return {} + parameter_types = parameter_types or {} + + result_values = {} + + for key, value in params.items(): + user_provided_type = parameter_types.get(key) + try: + if user_provided_type: + if not isinstance(user_provided_type, SqlType.Type): + raise ValueError( + f"Parameter type for {key} should be provided as an instance of SqlType.Type subclass." + ) + param_type = user_provided_type + else: + param_type = _detect_type(value) + + value_pb_dict = _convert_value_to_pb_value_dict(value, param_type) + except ValueError as err: + raise ValueError(f"Error when parsing parameter {key}") from err + result_values[key] = value_pb_dict + + return result_values + + +def _convert_value_to_pb_value_dict( + value: ExecuteQueryValueType, param_type: SqlType.Type +) -> Any: + """ + Takes a value and converts it to a dictionary parsable to a protobuf. + + Args: + value (ExecuteQueryValueType): value + param_type (SqlType.Type): object describing which ExecuteQuery type the value represents. + + Returns: + dictionary parsable to a protobuf. + """ + # type field will be set only in top-level Value. + value_dict = param_type._to_value_pb_dict(value) + value_dict["type_"] = param_type._to_type_pb_dict() + return value_dict + + +_TYPES_TO_TYPE_DICTS = [ + (bytes, SqlType.Bytes()), + (str, SqlType.String()), + (bool, SqlType.Bool()), + (int, SqlType.Int64()), + (DatetimeWithNanoseconds, SqlType.Timestamp()), + (datetime.datetime, SqlType.Timestamp()), + (datetime.date, SqlType.Date()), +] + + +def _detect_type(value: ExecuteQueryValueType) -> SqlType.Type: + """ + Infers the ExecuteQuery type based on the value. + Raises ParameterTypeInferenceFailed if the type cannot be inferred. + """ + if value is None: + raise ParameterTypeInferenceFailed( + "Cannot infer type of None, please provide the type manually." + ) + + for field_type, type_dict in _TYPES_TO_TYPE_DICTS: + if isinstance(value, field_type): + return type_dict + + raise ParameterTypeInferenceFailed( + f"Cannot infer type of {type(value).__name__}, please provide the type manually."
+ ) diff --git a/google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py b/google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py new file mode 100644 index 000000000..b65dce27b --- /dev/null +++ b/google/cloud/bigtable/data/execute_query/_query_result_parsing_utils.py @@ -0,0 +1,133 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, Type +from google.cloud.bigtable.data.execute_query.values import Struct +from google.cloud.bigtable.data.execute_query.metadata import SqlType +from google.cloud.bigtable_v2 import Value as PBValue +from google.api_core.datetime_helpers import DatetimeWithNanoseconds + +_REQUIRED_PROTO_FIELDS = { + SqlType.Bytes: "bytes_value", + SqlType.String: "string_value", + SqlType.Int64: "int_value", + SqlType.Float64: "float_value", + SqlType.Bool: "bool_value", + SqlType.Timestamp: "timestamp_value", + SqlType.Date: "date_value", + SqlType.Struct: "array_value", + SqlType.Array: "array_value", + SqlType.Map: "array_value", +} + + +def _parse_array_type(value: PBValue, metadata_type: SqlType.Array) -> Any: + """ + used for parsing an array represented as a protobuf to a python list. + """ + return list( + map( + lambda val: _parse_pb_value_to_python_value( + val, metadata_type.element_type + ), + value.array_value.values, + ) + ) + + +def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> Any: + """ + used for parsing a map represented as a protobuf to a python dict. + + Values of type `Map` are stored in a `Value.array_value` where each entry + is another `Value.array_value` with two elements (the key and the value, + in that order). + Normally encoded Map values won't have repeated keys, however, the client + must handle the case in which they do. If the same key appears + multiple times, the _last_ value takes precedence. 
+ """ + + try: + return dict( + map( + lambda map_entry: ( + _parse_pb_value_to_python_value( + map_entry.array_value.values[0], metadata_type.key_type + ), + _parse_pb_value_to_python_value( + map_entry.array_value.values[1], metadata_type.value_type + ), + ), + value.array_value.values, + ) + ) + except IndexError: + raise ValueError("Invalid map entry - less or more than two values.") + + +def _parse_struct_type(value: PBValue, metadata_type: SqlType.Struct) -> Struct: + """ + used for parsing a struct represented as a protobuf to a + google.cloud.bigtable.data.execute_query.Struct + """ + if len(value.array_value.values) != len(metadata_type.fields): + raise ValueError("Mismatched lengths of values and types.") + + struct = Struct() + for value, field in zip(value.array_value.values, metadata_type.fields): + field_name, field_type = field + struct.add_field(field_name, _parse_pb_value_to_python_value(value, field_type)) + + return struct + + +def _parse_timestamp_type( + value: PBValue, metadata_type: SqlType.Timestamp +) -> DatetimeWithNanoseconds: + """ + used for parsing a timestamp represented as a protobuf to DatetimeWithNanoseconds + """ + return DatetimeWithNanoseconds.from_timestamp_pb(value.timestamp_value) + + +_TYPE_PARSERS: Dict[Type[SqlType.Type], Callable[[PBValue, Any], Any]] = { + SqlType.Timestamp: _parse_timestamp_type, + SqlType.Struct: _parse_struct_type, + SqlType.Array: _parse_array_type, + SqlType.Map: _parse_map_type, +} + + +def _parse_pb_value_to_python_value(value: PBValue, metadata_type: SqlType.Type) -> Any: + """ + used for converting the value represented as a protobufs to a python object. + """ + value_kind = value.WhichOneof("kind") + if not value_kind: + return None + + kind = type(metadata_type) + if not value.HasField(_REQUIRED_PROTO_FIELDS[kind]): + raise ValueError( + f"{_REQUIRED_PROTO_FIELDS[kind]} field for {kind.__name__} type not found in a Value." + ) + + if kind in _TYPE_PARSERS: + parser = _TYPE_PARSERS[kind] + return parser(value, metadata_type) + elif kind in _REQUIRED_PROTO_FIELDS: + field_name = _REQUIRED_PROTO_FIELDS[kind] + return getattr(value, field_name) + else: + raise ValueError(f"Unknown kind {kind}") diff --git a/google/cloud/bigtable/data/execute_query/_reader.py b/google/cloud/bigtable/data/execute_query/_reader.py new file mode 100644 index 000000000..9c0259cde --- /dev/null +++ b/google/cloud/bigtable/data/execute_query/_reader.py @@ -0,0 +1,149 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
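For context, a short sketch (not part of the patch) of how the parsing helpers above turn a protobuf Value into a Python value. The ARRAY<STRING> payload is made up, and the raw ._pb message is used because the parser operates on plain protobuf Values, as produced by ProtoRows.pb().FromString in the reader below.

from google.cloud.bigtable_v2 import Value as PBValue
from google.cloud.bigtable.data.execute_query.metadata import SqlType
from google.cloud.bigtable.data.execute_query._query_result_parsing_utils import (
    _parse_pb_value_to_python_value,
)

# An ARRAY<STRING> value as it could appear in a ProtoRows stream.
value = PBValue(
    {"array_value": {"values": [{"string_value": "a"}, {"string_value": "b"}]}}
)._pb

# The metadata type drives which oneof field is read and how it is converted.
assert _parse_pb_value_to_python_value(value, SqlType.Array(SqlType.String())) == ["a", "b"]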
+ +from typing import ( + TypeVar, + Generic, + Iterable, + Optional, + List, + Sequence, + cast, +) +from abc import ABC, abstractmethod + +from google.cloud.bigtable_v2 import ProtoRows, Value as PBValue +from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor + +from google.cloud.bigtable.data.execute_query._query_result_parsing_utils import ( + _parse_pb_value_to_python_value, +) + +from google.cloud.bigtable.helpers import batched + +from google.cloud.bigtable.data.execute_query.values import QueryResultRow +from google.cloud.bigtable.data.execute_query.metadata import ProtoMetadata + + +T = TypeVar("T") + + +class _Reader(ABC, Generic[T]): + """ + An interface for classes that consume and parse bytes returned by ``_ByteCursor``. + Parsed bytes should be gathered into bundles (rows or columns) of expected size + and converted to an appropriate type ``T`` that will be returned as a semantically + meaningful result to the library user by + :meth:`google.cloud.bigtable.instance.Instance.execute_query` or + :meth:`google.cloud.bigtable.data._async.client.BigtableDataClientAsync.execute_query` + methods. + + This class consumes data obtained externally to be usable in both sync and async clients. + + See :class:`google.cloud.bigtable.byte_cursor._ByteCursor` for more context. + """ + + @abstractmethod + def consume(self, bytes_to_consume: bytes) -> Optional[Iterable[T]]: + """This method receives a parsable chunk of bytes and returns either a None if there is + not enough chunks to return to the user yet (e.g. we haven't received all columns in a + row yet), or a list of appropriate values gathered from one or more parsable chunks. + + Args: + bytes_to_consume (bytes): chunk of parsable bytes received from + :meth:`google.cloud.bigtable.byte_cursor._ByteCursor.consume` + method. + + Returns: + Iterable[T] or None: Iterable if gathered values can form one or more instances of T, + or None if there is not enough data to construct at least one instance of T with + appropriate number of entries. + """ + raise NotImplementedError + + +class _QueryResultRowReader(_Reader[QueryResultRow]): + """ + A :class:`._Reader` consuming bytes representing + :class:`google.cloud.bigtable_v2.types.Type` + and producing :class:`google.cloud.bigtable.execute_query.QueryResultRow`. + + Number of entries in each row is determined by number of columns in + :class:`google.cloud.bigtable.execute_query.Metadata` obtained from + :class:`google.cloud.bigtable.byte_cursor._ByteCursor` passed in the constructor. + """ + + def __init__(self, byte_cursor: _ByteCursor[ProtoMetadata]): + """ + Constructs new instance of ``_QueryResultRowReader``. + + Args: + byte_cursor (google.cloud.bigtable.byte_cursor._ByteCursor): + byte_cursor that will be used to gather bytes for this instance of ``_Reader``, + needed to obtain :class:`google.cloud.bigtable.execute_query.Metadata` about + processed stream. 
+ """ + self._values: List[PBValue] = [] + self._byte_cursor = byte_cursor + + @property + def _metadata(self) -> Optional[ProtoMetadata]: + return self._byte_cursor.metadata + + def _construct_query_result_row(self, values: Sequence[PBValue]) -> QueryResultRow: + result = QueryResultRow() + # The logic, not defined by mypy types, ensures that the value of + # "metadata" is never null at the time it is retrieved here + metadata = cast(ProtoMetadata, self._metadata) + columns = metadata.columns + + assert len(values) == len( + columns + ), "This function should be called only when count of values matches count of columns." + + for column, value in zip(columns, values): + parsed_value = _parse_pb_value_to_python_value(value, column.column_type) + result.add_field(column.column_name, parsed_value) + return result + + def _parse_proto_rows(self, bytes_to_parse: bytes) -> Iterable[PBValue]: + proto_rows = ProtoRows.pb().FromString(bytes_to_parse) + return proto_rows.values + + def consume(self, bytes_to_consume: bytes) -> Optional[Iterable[QueryResultRow]]: + if bytes_to_consume is None: + raise ValueError("bytes_to_consume shouldn't be None") + + self._values.extend(self._parse_proto_rows(bytes_to_consume)) + + # The logic, not defined by mypy types, ensures that the value of + # "metadata" is never null at the time it is retrieved here + num_columns = len(cast(ProtoMetadata, self._metadata).columns) + + if len(self._values) < num_columns: + return None + + rows = [] + for batch in batched(self._values, n=num_columns): + if len(batch) == num_columns: + rows.append(self._construct_query_result_row(batch)) + else: + raise ValueError( + "Server error, recieved bad number of values. " + f"Expected {num_columns} got {len(batch)}." + ) + + self._values = [] + + return rows diff --git a/google/cloud/bigtable/data/execute_query/metadata.py b/google/cloud/bigtable/data/execute_query/metadata.py new file mode 100644 index 000000000..98b94a644 --- /dev/null +++ b/google/cloud/bigtable/data/execute_query/metadata.py @@ -0,0 +1,354 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides the SqlType class used for specifying types in +ExecuteQuery and some utilities. + +The SqlTypes are used in Metadata returned by the ExecuteQuery operation as well +as for specifying query parameter types explicitly. +""" + +from collections import defaultdict +from typing import ( + Optional, + List, + Dict, + Set, + Type, + Union, + Tuple, + Any, +) +from google.cloud.bigtable.data.execute_query.values import _NamedList +from google.cloud.bigtable_v2 import ResultSetMetadata +from google.cloud.bigtable_v2 import Type as PBType +from google.type import date_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from google.api_core.datetime_helpers import DatetimeWithNanoseconds +import datetime + + +class SqlType: + """ + Classes denoting types of values returned by Bigtable's ExecuteQuery operation. + + Used in :class:`.Metadata`. 
+ """ + + class Type: + expected_type: Optional[type] = None + value_pb_dict_field_name: Optional[str] = None + type_field_name: Optional[str] = None + + @classmethod + def from_pb_type(cls, pb_type: Optional[PBType] = None): + return cls() + + def _to_type_pb_dict(self) -> Dict[str, Any]: + if not self.type_field_name: + raise NotImplementedError( + "Fill in expected_type and value_pb_dict_field_name" + ) + + return {self.type_field_name: {}} + + def _to_value_pb_dict(self, value: Any) -> Dict[str, Any]: + if self.expected_type is None or self.value_pb_dict_field_name is None: + raise NotImplementedError( + "Fill in expected_type and value_pb_dict_field_name" + ) + + if value is None: + return {} + + if not isinstance(value, self.expected_type): + raise ValueError( + f"Expected query parameter of type {self.expected_type.__name__}, got {type(value).__name__}" + ) + + return {self.value_pb_dict_field_name: value} + + def __eq__(self, other): + return isinstance(other, type(self)) + + def __str__(self) -> str: + return self.__class__.__name__ + + def __repr__(self) -> str: + return self.__str__() + + class Struct(_NamedList[Type], Type): + @classmethod + def from_pb_type(cls, type_pb: Optional[PBType] = None) -> "SqlType.Struct": + if type_pb is None: + raise ValueError("missing required argument type_pb") + fields: List[Tuple[Optional[str], SqlType.Type]] = [] + for field in type_pb.struct_type.fields: + fields.append((field.field_name, _pb_type_to_metadata_type(field.type))) + return cls(fields) + + def _to_value_pb_dict(self, value: Any): + raise NotImplementedError("Struct is not supported as a query parameter") + + def _to_type_pb_dict(self) -> Dict[str, Any]: + raise NotImplementedError("Struct is not supported as a query parameter") + + def __eq__(self, other: object): + # Cannot use super() here - we'd either have to: + # - call super() in these base classes, which would in turn call Object.__eq__ + # to compare objects by identity and return a False, or + # - do not call super() in these base classes, which would result in calling only + # one of the __eq__ methods (a super() in the base class would be required to call the other one), or + # - call super() in only one of the base classes, but that would be error prone and changing + # the order of base classes would introduce unexpected behaviour. 
+ # we also have to disable mypy because it doesn't see that SqlType.Struct == _NamedList[Type] + return SqlType.Type.__eq__(self, other) and _NamedList.__eq__(self, other) # type: ignore + + def __str__(self): + return super(_NamedList, self).__str__() + + class Array(Type): + def __init__(self, element_type: "SqlType.Type"): + if isinstance(element_type, SqlType.Array): + raise ValueError("Arrays of arrays are not supported.") + self._element_type = element_type + + @property + def element_type(self): + return self._element_type + + @classmethod + def from_pb_type(cls, type_pb: Optional[PBType] = None) -> "SqlType.Array": + if type_pb is None: + raise ValueError("missing required argument type_pb") + return cls(_pb_type_to_metadata_type(type_pb.array_type.element_type)) + + def _to_value_pb_dict(self, value: Any): + raise NotImplementedError("Array is not supported as a query parameter") + + def _to_type_pb_dict(self) -> Dict[str, Any]: + raise NotImplementedError("Array is not supported as a query parameter") + + def __eq__(self, other): + return super().__eq__(other) and self.element_type == other.element_type + + def __str__(self) -> str: + return f"{self.__class__.__name__}<{str(self.element_type)}>" + + class Map(Type): + def __init__(self, key_type: "SqlType.Type", value_type: "SqlType.Type"): + self._key_type = key_type + self._value_type = value_type + + @property + def key_type(self): + return self._key_type + + @property + def value_type(self): + return self._value_type + + @classmethod + def from_pb_type(cls, type_pb: Optional[PBType] = None) -> "SqlType.Map": + if type_pb is None: + raise ValueError("missing required argument type_pb") + return cls( + _pb_type_to_metadata_type(type_pb.map_type.key_type), + _pb_type_to_metadata_type(type_pb.map_type.value_type), + ) + + def _to_type_pb_dict(self) -> Dict[str, Any]: + raise NotImplementedError("Map is not supported as a query parameter") + + def _to_value_pb_dict(self, value: Any): + raise NotImplementedError("Map is not supported as a query parameter") + + def __eq__(self, other): + return ( + super().__eq__(other) + and self.key_type == other.key_type + and self.value_type == other.value_type + ) + + def __str__(self) -> str: + return ( + f"{self.__class__.__name__}<" + f"{str(self._key_type)},{str(self._value_type)}>" + ) + + class Bytes(Type): + expected_type = bytes + value_pb_dict_field_name = "bytes_value" + type_field_name = "bytes_type" + + class String(Type): + expected_type = str + value_pb_dict_field_name = "string_value" + type_field_name = "string_type" + + class Int64(Type): + expected_type = int + value_pb_dict_field_name = "int_value" + type_field_name = "int64_type" + + class Float64(Type): + expected_type = float + value_pb_dict_field_name = "float_value" + type_field_name = "float64_type" + + class Bool(Type): + expected_type = bool + value_pb_dict_field_name = "bool_value" + type_field_name = "bool_type" + + class Timestamp(Type): + type_field_name = "timestamp_type" + expected_types = ( + datetime.datetime, + DatetimeWithNanoseconds, + ) + + def _to_value_pb_dict(self, value: Any) -> Dict[str, Any]: + if value is None: + return {} + + if not isinstance(value, self.expected_types): + raise ValueError( + f"Expected one of {', '.join((_type.__name__ for _type in self.expected_types))}" + ) + + if isinstance(value, DatetimeWithNanoseconds): + return {"timestamp_value": value.timestamp_pb()} + else: # value must be an instance of datetime.datetime + ts = timestamp_pb2.Timestamp() + ts.FromDatetime(value) + 
return {"timestamp_value": ts} + + class Date(Type): + type_field_name = "date_type" + expected_type = datetime.date + + def _to_value_pb_dict(self, value: Any) -> Dict[str, Any]: + if value is None: + return {} + + if not isinstance(value, self.expected_type): + raise ValueError( + f"Expected query parameter of type {self.expected_type.__name__}, got {type(value).__name__}" + ) + + return { + "date_value": date_pb2.Date( + year=value.year, + month=value.month, + day=value.day, + ) + } + + +class Metadata: + pass + + +class ProtoMetadata(Metadata): + class Column: + def __init__(self, column_name: Optional[str], column_type: SqlType.Type): + self._column_name = column_name + self._column_type = column_type + + @property + def column_name(self) -> Optional[str]: + return self._column_name + + @property + def column_type(self) -> SqlType.Type: + return self._column_type + + @property + def columns(self) -> List[Column]: + return self._columns + + def __init__( + self, columns: Optional[List[Tuple[Optional[str], SqlType.Type]]] = None + ): + self._columns: List[ProtoMetadata.Column] = [] + self._column_indexes: Dict[str, List[int]] = defaultdict(list) + self._duplicate_names: Set[str] = set() + + if columns: + for column_name, column_type in columns: + if column_name is not None: + if column_name in self._column_indexes: + self._duplicate_names.add(column_name) + self._column_indexes[column_name].append(len(self._columns)) + self._columns.append(ProtoMetadata.Column(column_name, column_type)) + + def __getitem__(self, index_or_name: Union[str, int]) -> Column: + if isinstance(index_or_name, str): + if index_or_name in self._duplicate_names: + raise KeyError( + f"Ambigious column name: '{index_or_name}', use index instead." + f" Field present on indexes {', '.join(map(str, self._column_indexes[index_or_name]))}." 
+ ) + if index_or_name not in self._column_indexes: + raise KeyError(f"No such column: {index_or_name}") + index = self._column_indexes[index_or_name][0] + else: + index = index_or_name + return self._columns[index] + + def __len__(self): + return len(self._columns) + + def __str__(self) -> str: + columns_str = ", ".join([str(column) for column in self._columns]) + return f"{self.__class__.__name__}([{columns_str}])" + + def __repr__(self) -> str: + return self.__str__() + + +def _pb_metadata_to_metadata_types( + metadata_pb: ResultSetMetadata, +) -> Metadata: + if "proto_schema" in metadata_pb: + fields: List[Tuple[Optional[str], SqlType.Type]] = [] + for column_metadata in metadata_pb.proto_schema.columns: + fields.append( + (column_metadata.name, _pb_type_to_metadata_type(column_metadata.type)) + ) + return ProtoMetadata(fields) + raise ValueError("Invalid ResultSetMetadata object received.") + + +_PROTO_TYPE_TO_METADATA_TYPE_FACTORY: Dict[str, Type[SqlType.Type]] = { + "bytes_type": SqlType.Bytes, + "string_type": SqlType.String, + "int64_type": SqlType.Int64, + "float64_type": SqlType.Float64, + "bool_type": SqlType.Bool, + "timestamp_type": SqlType.Timestamp, + "date_type": SqlType.Date, + "struct_type": SqlType.Struct, + "array_type": SqlType.Array, + "map_type": SqlType.Map, +} + + +def _pb_type_to_metadata_type(type_pb: PBType) -> SqlType.Type: + kind = PBType.pb(type_pb).WhichOneof("kind") + if kind in _PROTO_TYPE_TO_METADATA_TYPE_FACTORY: + return _PROTO_TYPE_TO_METADATA_TYPE_FACTORY[kind].from_pb_type(type_pb) + raise ValueError(f"Unrecognized response data type: {type_pb}") diff --git a/google/cloud/bigtable/data/execute_query/values.py b/google/cloud/bigtable/data/execute_query/values.py new file mode 100644 index 000000000..450f6f855 --- /dev/null +++ b/google/cloud/bigtable/data/execute_query/values.py @@ -0,0 +1,116 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from typing import ( + Optional, + List, + Dict, + Set, + Union, + TypeVar, + Generic, + Tuple, + Mapping, +) +from google.type import date_pb2 # type: ignore +from google.api_core.datetime_helpers import DatetimeWithNanoseconds + +T = TypeVar("T") + + +class _NamedList(Generic[T]): + """ + A class designed to store a list of elements, which can be accessed by + name or index. + This class is different from namedtuple, because namedtuple has some + restrictions on names of fields and we do not want to have them. 
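+
+    For example (illustrative usage)::
+
+        values = _NamedList([("a", 1), ("b", 2)])
+        values["a"]  # returns 1 (lookup by name)
+        values[1]    # returns 2 (lookup by index)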
+ """ + + _str_cls_name = "_NamedList" + + def __init__(self, fields: Optional[List[Tuple[Optional[str], T]]] = None): + self._fields: List[Tuple[Optional[str], T]] = [] + self._field_indexes: Dict[str, List[int]] = defaultdict(list) + self._duplicate_names: Set[str] = set() + + if fields: + for field_name, field_type in fields: + self.add_field(field_name, field_type) + + def add_field(self, name: Optional[str], value: T): + if name: + if name in self._field_indexes: + self._duplicate_names.add(name) + self._field_indexes[name].append(len(self._fields)) + self._fields.append((name, value)) + + @property + def fields(self): + return self._fields + + def __getitem__(self, index_or_name: Union[str, int]): + if isinstance(index_or_name, str): + if index_or_name in self._duplicate_names: + raise KeyError( + f"Ambigious field name: '{index_or_name}', use index instead." + f" Field present on indexes {', '.join(map(str, self._field_indexes[index_or_name]))}." + ) + if index_or_name not in self._field_indexes: + raise KeyError(f"No such field: {index_or_name}") + index = self._field_indexes[index_or_name][0] + else: + index = index_or_name + return self._fields[index][1] + + def __len__(self): + return len(self._fields) + + def __eq__(self, other): + if not isinstance(other, _NamedList): + return False + + return ( + self._fields == other._fields + and self._field_indexes == other._field_indexes + ) + + def __str__(self) -> str: + fields_str = ", ".join([str(field) for field in self._fields]) + return f"{self.__class__.__name__}([{fields_str}])" + + def __repr__(self) -> str: + return self.__str__() + + +ExecuteQueryValueType = Union[ + int, + float, + bool, + bytes, + str, + DatetimeWithNanoseconds, + date_pb2.Date, + "Struct", + List["ExecuteQueryValueType"], + Mapping[Union[str, int, bytes], "ExecuteQueryValueType"], +] + + +class QueryResultRow(_NamedList[ExecuteQueryValueType]): + pass + + +class Struct(_NamedList[ExecuteQueryValueType]): + pass diff --git a/google/cloud/bigtable/helpers.py b/google/cloud/bigtable/helpers.py new file mode 100644 index 000000000..78af43089 --- /dev/null +++ b/google/cloud/bigtable/helpers.py @@ -0,0 +1,31 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TypeVar, Iterable, Generator, Tuple + +from itertools import islice + +T = TypeVar("T") + + +# batched landed in standard library in Python 3.11. 
+def batched(iterable: Iterable[T], n) -> Generator[Tuple[T, ...], None, None]: + # batched('ABCDEFG', 3) → ABC DEF G + if n < 1: + raise ValueError("n must be at least one") + it = iter(iterable) + batch = tuple(islice(it, n)) + while batch: + yield batch + batch = tuple(islice(it, n)) diff --git a/google/cloud/bigtable/instance.py b/google/cloud/bigtable/instance.py index 6d092cefd..23fb1c95d 100644 --- a/google/cloud/bigtable/instance.py +++ b/google/cloud/bigtable/instance.py @@ -32,6 +32,7 @@ import warnings + _INSTANCE_NAME_RE = re.compile( r"^projects/(?P[^/]+)/" r"instances/(?P[a-z][-a-z0-9]*)$" ) diff --git a/samples/snippets/data_client/data_client_snippets_async.py b/samples/snippets/data_client/data_client_snippets_async.py index 742e7cb8e..dabbcb839 100644 --- a/samples/snippets/data_client/data_client_snippets_async.py +++ b/samples/snippets/data_client/data_client_snippets_async.py @@ -69,7 +69,10 @@ async def write_batch(project_id, instance_id, table_id): for sub_exception in e.exceptions: failed_entry: RowMutationEntry = sub_exception.entry cause: Exception = sub_exception.__cause__ - print(f"Failed mutation: {failed_entry.row_key} with error: {cause!r}") + print( + f"Failed mutation: {failed_entry.row_key} with error: {cause!r}" + ) + # [END bigtable_async_writes_batch] await write_batch(table.client.project, table.instance_id, table.table_id) @@ -94,6 +97,7 @@ async def write_increment(project_id, instance_id, table_id): # check result cell = result_row[0] print(f"{cell.row_key} value: {int(cell)}") + # [END bigtable_async_write_increment] await write_increment(table.client.project, table.instance_id, table.table_id) @@ -127,6 +131,7 @@ async def write_conditional(project_id, instance_id, table_id): ) if result is True: print("The row os_name was set to android") + # [END bigtable_async_writes_conditional] await write_conditional(table.client.project, table.instance_id, table.table_id) @@ -141,6 +146,7 @@ async def read_row(project_id, instance_id, table_id): row_key = "phone#4c410523#20190501" row = await table.read_row(row_key) print(row) + # [END bigtable_async_reads_row] await read_row(table.client.project, table.instance_id, table.table_id) @@ -158,6 +164,7 @@ async def read_row_partial(project_id, instance_id, table_id): row = await table.read_row(row_key, row_filter=col_filter) print(row) + # [END bigtable_async_reads_row_partial] await read_row_partial(table.client.project, table.instance_id, table.table_id) @@ -171,10 +178,9 @@ async def read_rows(project_id, instance_id, table_id): async with BigtableDataClientAsync(project=project_id) as client: async with client.get_table(instance_id, table_id) as table: - query = ReadRowsQuery(row_keys=[ - b"phone#4c410523#20190501", - b"phone#4c410523#20190502" - ]) + query = ReadRowsQuery( + row_keys=[b"phone#4c410523#20190501", b"phone#4c410523#20190502"] + ) async for row in await table.read_rows_stream(query): print(row) @@ -194,12 +200,13 @@ async def read_row_range(project_id, instance_id, table_id): row_range = RowRange( start_key=b"phone#4c410523#20190501", - end_key=b"phone#4c410523#201906201" + end_key=b"phone#4c410523#201906201", ) query = ReadRowsQuery(row_ranges=[row_range]) async for row in await table.read_rows_stream(query): print(row) + # [END bigtable_async_reads_row_range] await read_row_range(table.client.project, table.instance_id, table.table_id) @@ -221,6 +228,7 @@ async def read_prefix(project_id, instance_id, table_id): async for row in await table.read_rows_stream(query): print(row) + # [END 
bigtable_async_reads_prefix] await read_prefix(table.client.project, table.instance_id, table.table_id) @@ -240,5 +248,30 @@ async def read_with_filter(project_id, instance_id, table_id): async for row in await table.read_rows_stream(query): print(row) + # [END bigtable_async_reads_filter] await read_with_filter(table.client.project, table.instance_id, table.table_id) + + +async def execute_query(table): + # [START bigtable_async_execute_query] + from google.cloud.bigtable.data import BigtableDataClientAsync + + async def execute_query(project_id, instance_id, table_id): + async with BigtableDataClientAsync(project=project_id) as client: + query = ( + "SELECT _key, stats_summary['os_build'], " + "stats_summary['connected_cell'], " + "stats_summary['connected_wifi'] " + f"from `{table_id}` WHERE _key=@row_key" + ) + result = await client.execute_query( + query, + instance_id, + parameters={"row_key": b"phone#4c410523#20190501"}, + ) + results = [r async for r in result] + print(results) + + # [END bigtable_async_execute_query] + await execute_query(table.client.project, table.instance_id, table.table_id) diff --git a/samples/snippets/data_client/data_client_snippets_async_test.py b/samples/snippets/data_client/data_client_snippets_async_test.py index d9968e6dc..2e0fb9b81 100644 --- a/samples/snippets/data_client/data_client_snippets_async_test.py +++ b/samples/snippets/data_client/data_client_snippets_async_test.py @@ -101,3 +101,8 @@ async def test_read_with_prefix(table): @pytest.mark.asyncio async def test_read_with_filter(table): await data_snippets.read_with_filter(table) + + +@pytest.mark.asyncio +async def test_execute_query(table): + await data_snippets.execute_query(table) diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt index fa7c56db1..5ed0c2fb9 100644 --- a/testing/constraints-3.8.txt +++ b/testing/constraints-3.8.txt @@ -12,3 +12,4 @@ grpc-google-iam-v1==0.12.4 proto-plus==1.22.3 libcst==0.2.5 protobuf==3.20.2 +pytest-asyncio==0.21.2 diff --git a/tests/_testing.py b/tests/_testing.py new file mode 100644 index 000000000..81cce7b78 --- /dev/null +++ b/tests/_testing.py @@ -0,0 +1,36 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
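+
+# Shared helpers for the execute query tests: TYPE_INT is the proto dict for an
+# int64 column type, proto_rows_bytes serializes values into a ProtoRows payload,
+# and split_bytes_into_chunks splits such a payload into roughly equal parts to
+# simulate a chunked response stream.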
+ +from google.cloud.bigtable_v2.types.data import ProtoRows, Value as PBValue + + +TYPE_INT = { + "int64_type": { + "encoding": {"big_endian_bytes": {"bytes_type": {"encoding": {"raw": {}}}}} + } +} + + +def proto_rows_bytes(*args): + return ProtoRows.serialize(ProtoRows(values=[PBValue(**arg) for arg in args])) + + +def split_bytes_into_chunks(bytes_to_split, num_chunks): + from google.cloud.bigtable.helpers import batched + + assert num_chunks <= len(bytes_to_split) + bytes_per_part = (len(bytes_to_split) - 1) // num_chunks + 1 + result = list(map(bytes, batched(bytes_to_split, bytes_per_part))) + assert len(result) == num_chunks + return result diff --git a/tests/system/data/test_execute_query_async.py b/tests/system/data/test_execute_query_async.py new file mode 100644 index 000000000..a680d2de0 --- /dev/null +++ b/tests/system/data/test_execute_query_async.py @@ -0,0 +1,288 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import os +from unittest import mock +from .test_execute_query_utils import ( + ChannelMockAsync, + response_with_metadata, + response_with_result, +) +from google.api_core import exceptions as core_exceptions +from google.cloud.bigtable.data import BigtableDataClientAsync +import google.cloud.bigtable.data._async.client + +TABLE_NAME = "TABLE_NAME" +INSTANCE_NAME = "INSTANCE_NAME" + + +class TestAsyncExecuteQuery: + @pytest.fixture() + def async_channel_mock(self): + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + yield ChannelMockAsync() + + @pytest.fixture() + def async_client(self, async_channel_mock): + with mock.patch.dict( + os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"} + ), mock.patch.object( + google.cloud.bigtable.data._async.client, + "PooledChannel", + return_value=async_channel_mock, + ): + yield BigtableDataClientAsync() + + @pytest.mark.asyncio + async def test_execute_query(self, async_client, async_channel_mock): + values = [ + response_with_metadata(), + response_with_result("test"), + response_with_result(8, resume_token=b"r1"), + response_with_result("test2"), + response_with_result(9, resume_token=b"r2"), + response_with_result("test3"), + response_with_result(None, resume_token=b"r3"), + ] + async_channel_mock.set_values(values) + result = await async_client.execute_query( + f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME + ) + results = [r async for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert results[1]["a"] == "test2" + assert results[1]["b"] == 9 + assert results[2]["a"] == "test3" + assert results[2]["b"] is None + assert len(async_channel_mock.execute_query_calls) == 1 + + @pytest.mark.asyncio + async def test_execute_query_with_params(self, async_client, async_channel_mock): + values = [ + response_with_metadata(), + response_with_result("test2"), + response_with_result(9, resume_token=b"r2"), + ] + async_channel_mock.set_values(values) + + result = await async_client.execute_query( + f"SELECT a, b FROM {TABLE_NAME} 
WHERE b=@b", + INSTANCE_NAME, + parameters={"b": 9}, + ) + results = [r async for r in result] + assert len(results) == 1 + assert results[0]["a"] == "test2" + assert results[0]["b"] == 9 + assert len(async_channel_mock.execute_query_calls) == 1 + + @pytest.mark.asyncio + async def test_execute_query_error_before_metadata( + self, async_client, async_channel_mock + ): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + DeadlineExceeded(""), + response_with_metadata(), + response_with_result("test"), + response_with_result(8, resume_token=b"r1"), + response_with_result("test2"), + response_with_result(9, resume_token=b"r2"), + response_with_result("test3"), + response_with_result(None, resume_token=b"r3"), + ] + async_channel_mock.set_values(values) + + result = await async_client.execute_query( + f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME + ) + results = [r async for r in result] + assert len(results) == 3 + assert len(async_channel_mock.execute_query_calls) == 2 + + @pytest.mark.asyncio + async def test_execute_query_error_after_metadata( + self, async_client, async_channel_mock + ): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + response_with_metadata(), + DeadlineExceeded(""), + response_with_metadata(), + response_with_result("test"), + response_with_result(8, resume_token=b"r1"), + response_with_result("test2"), + response_with_result(9, resume_token=b"r2"), + response_with_result("test3"), + response_with_result(None, resume_token=b"r3"), + ] + async_channel_mock.set_values(values) + + result = await async_client.execute_query( + f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME + ) + results = [r async for r in result] + assert len(results) == 3 + assert len(async_channel_mock.execute_query_calls) == 2 + assert async_channel_mock.resume_tokens == [] + + @pytest.mark.asyncio + async def test_execute_query_with_retries(self, async_client, async_channel_mock): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + response_with_metadata(), + response_with_result("test"), + response_with_result(8, resume_token=b"r1"), + DeadlineExceeded(""), + response_with_result("test2"), + response_with_result(9, resume_token=b"r2"), + response_with_result("test3"), + DeadlineExceeded(""), + response_with_result("test3"), + response_with_result(None, resume_token=b"r3"), + ] + async_channel_mock.set_values(values) + + result = await async_client.execute_query( + f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME + ) + results = [r async for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert results[1]["a"] == "test2" + assert results[1]["b"] == 9 + assert results[2]["a"] == "test3" + assert results[2]["b"] is None + assert len(async_channel_mock.execute_query_calls) == 3 + assert async_channel_mock.resume_tokens == [b"r1", b"r2"] + + @pytest.mark.parametrize( + "exception", + [ + (core_exceptions.DeadlineExceeded("")), + (core_exceptions.Aborted("")), + (core_exceptions.ServiceUnavailable("")), + ], + ) + @pytest.mark.asyncio + async def test_execute_query_retryable_error( + self, async_client, async_channel_mock, exception + ): + values = [ + response_with_metadata(), + response_with_result("test", resume_token=b"t1"), + exception, + response_with_result(8, resume_token=b"t2"), + ] + async_channel_mock.set_values(values) + + result = await async_client.execute_query( + f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME + ) + results = [r async for r in result] + assert len(results) == 1 + assert 
len(async_channel_mock.execute_query_calls) == 2 + assert async_channel_mock.resume_tokens == [b"t1"] + + @pytest.mark.asyncio + async def test_execute_query_retry_partial_row( + self, async_client, async_channel_mock + ): + values = [ + response_with_metadata(), + response_with_result("test", resume_token=b"t1"), + core_exceptions.DeadlineExceeded(""), + response_with_result(8, resume_token=b"t2"), + ] + async_channel_mock.set_values(values) + + result = await async_client.execute_query( + f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME + ) + results = [r async for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert len(async_channel_mock.execute_query_calls) == 2 + assert async_channel_mock.resume_tokens == [b"t1"] + + @pytest.mark.parametrize( + "ExceptionType", + [ + (core_exceptions.InvalidArgument), + (core_exceptions.FailedPrecondition), + (core_exceptions.PermissionDenied), + (core_exceptions.MethodNotImplemented), + (core_exceptions.Cancelled), + (core_exceptions.AlreadyExists), + (core_exceptions.OutOfRange), + (core_exceptions.DataLoss), + (core_exceptions.Unauthenticated), + (core_exceptions.NotFound), + (core_exceptions.ResourceExhausted), + (core_exceptions.Unknown), + (core_exceptions.InternalServerError), + ], + ) + @pytest.mark.asyncio + async def test_execute_query_non_retryable( + self, async_client, async_channel_mock, ExceptionType + ): + values = [ + response_with_metadata(), + response_with_result("test"), + response_with_result(8, resume_token=b"r1"), + ExceptionType(""), + response_with_result("test2"), + response_with_result(9, resume_token=b"r2"), + response_with_result("test3"), + response_with_result(None, resume_token=b"r3"), + ] + async_channel_mock.set_values(values) + + result = await async_client.execute_query( + f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME + ) + r = await result.__anext__() + assert r["a"] == "test" + assert r["b"] == 8 + + with pytest.raises(ExceptionType): + r = await result.__anext__() + + assert len(async_channel_mock.execute_query_calls) == 1 + assert async_channel_mock.resume_tokens == [] + + @pytest.mark.asyncio + async def test_execute_query_metadata_received_multiple_times_detected( + self, async_client, async_channel_mock + ): + values = [ + response_with_metadata(), + response_with_metadata(), + ] + async_channel_mock.set_values(values) + + with pytest.raises(Exception, match="Invalid ExecuteQuery response received"): + [ + r + async for r in await async_client.execute_query( + f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME + ) + ] diff --git a/tests/system/data/test_execute_query_utils.py b/tests/system/data/test_execute_query_utils.py new file mode 100644 index 000000000..9e27b95f2 --- /dev/null +++ b/tests/system/data/test_execute_query_utils.py @@ -0,0 +1,272 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
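+
+# Test doubles for the ExecuteQuery system tests: response_with_metadata and
+# response_with_result build ExecuteQueryResponse protos, ExecuteQueryStreamMock
+# replays a preset list of responses (raising any exception placed in the list),
+# and ChannelMock / ChannelMockAsync stand in for gRPC channels while recording
+# execute_query requests and their resume tokens.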
+ +from unittest import mock + +import google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio as pga +from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse +from google.cloud.bigtable_v2.types.data import ProtoRows, Value as PBValue +import grpc.aio + + +try: + # async mock for python3.7-10 + from asyncio import coroutine + + def async_mock(return_value=None): + coro = mock.Mock(name="CoroutineResult") + corofunc = mock.Mock(name="CoroutineFunction", side_effect=coroutine(coro)) + corofunc.coro = coro + corofunc.coro.return_value = return_value + return corofunc + +except ImportError: + # async mock for python3.11 or later + from unittest.mock import AsyncMock + + def async_mock(return_value=None): + return AsyncMock(return_value=return_value) + + +# ExecuteQueryResponse( +# metadata={ +# "proto_schema": { +# "columns": [ +# {"name": "test1", "type_": TYPE_INT}, +# {"name": "test2", "type_": TYPE_INT}, +# ] +# } +# } +# ), +# ExecuteQueryResponse( +# results={"proto_rows_batch": {"batch_data": messages[0]}} +# ), + + +def response_with_metadata(): + schema = {"a": "string_type", "b": "int64_type"} + return ExecuteQueryResponse( + { + "metadata": { + "proto_schema": { + "columns": [ + {"name": name, "type_": {_type: {}}} + for name, _type in schema.items() + ] + } + } + } + ) + + +def response_with_result(*args, resume_token=None): + if resume_token is None: + resume_token_dict = {} + else: + resume_token_dict = {"resume_token": resume_token} + + values = [] + for column_value in args: + if column_value is None: + pb_value = PBValue({}) + else: + pb_value = PBValue( + { + "int_value" + if isinstance(column_value, int) + else "string_value": column_value + } + ) + values.append(pb_value) + rows = ProtoRows(values=values) + + return ExecuteQueryResponse( + { + "results": { + "proto_rows_batch": { + "batch_data": ProtoRows.serialize(rows), + }, + **resume_token_dict, + } + } + ) + + +class ExecuteQueryStreamMock: + def __init__(self, parent): + self.parent = parent + self.iter = iter(self.parent.values) + + def __call__(self, *args, **kwargs): + request = args[0] + + self.parent.execute_query_calls.append(request) + if request.resume_token: + self.parent.resume_tokens.append(request.resume_token) + + def stream(): + for value in self.iter: + if isinstance(value, Exception): + raise value + else: + yield value + + return stream() + + +class ChannelMock: + def __init__(self): + self.execute_query_calls = [] + self.values = [] + self.resume_tokens = [] + + def set_values(self, values): + self.values = values + + def unary_unary(self, *args, **kwargs): + return mock.MagicMock() + + def unary_stream(self, *args, **kwargs): + if args[0] == "/google.bigtable.v2.Bigtable/ExecuteQuery": + return ExecuteQueryStreamMock(self) + return mock.MagicMock() + + +class ChannelMockAsync(pga.PooledChannel, mock.MagicMock): + def __init__(self, *args, **kwargs): + mock.MagicMock.__init__(self, *args, **kwargs) + self.execute_query_calls = [] + self.values = [] + self.resume_tokens = [] + self._iter = [] + + def get_async_get(self, *args, **kwargs): + return self.async_gen + + def set_values(self, values): + self.values = values + self._iter = iter(self.values) + + def unary_unary(self, *args, **kwargs): + return async_mock() + + def unary_stream(self, *args, **kwargs): + if args[0] == "/google.bigtable.v2.Bigtable/ExecuteQuery": + + async def async_gen(*args, **kwargs): + for value in self._iter: + yield value + + iter = async_gen() + + class 
UnaryStreamCallMock(grpc.aio.UnaryStreamCall): + def __aiter__(self): + async def _impl(*args, **kwargs): + try: + while True: + yield await self.read() + except StopAsyncIteration: + pass + + return _impl() + + async def read(self): + value = await iter.__anext__() + if isinstance(value, Exception): + raise value + return value + + def add_done_callback(*args, **kwargs): + pass + + def cancel(*args, **kwargs): + pass + + def cancelled(*args, **kwargs): + pass + + def code(*args, **kwargs): + pass + + def details(*args, **kwargs): + pass + + def done(*args, **kwargs): + pass + + def initial_metadata(*args, **kwargs): + pass + + def time_remaining(*args, **kwargs): + pass + + def trailing_metadata(*args, **kwargs): + pass + + async def wait_for_connection(*args, **kwargs): + return async_mock() + + class UnaryStreamMultiCallableMock(grpc.aio.UnaryStreamMultiCallable): + def __init__(self, parent): + self.parent = parent + + def __call__( + self, + request, + *, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None + ): + self.parent.execute_query_calls.append(request) + if request.resume_token: + self.parent.resume_tokens.append(request.resume_token) + return UnaryStreamCallMock() + + def add_done_callback(*args, **kwargs): + pass + + def cancel(*args, **kwargs): + pass + + def cancelled(*args, **kwargs): + pass + + def code(*args, **kwargs): + pass + + def details(*args, **kwargs): + pass + + def done(*args, **kwargs): + pass + + def initial_metadata(*args, **kwargs): + pass + + def time_remaining(*args, **kwargs): + pass + + def trailing_metadata(*args, **kwargs): + pass + + def wait_for_connection(*args, **kwargs): + pass + + # unary_stream should return https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc.aio.UnaryStreamMultiCallable + # PTAL https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc.aio.Channel.unary_stream + return UnaryStreamMultiCallableMock(self) + return async_mock() diff --git a/tests/unit/_testing.py b/tests/unit/_testing.py new file mode 100644 index 000000000..e0d8d2a22 --- /dev/null +++ b/tests/unit/_testing.py @@ -0,0 +1,16 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa +from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes diff --git a/tests/unit/data/_async/__init__.py b/tests/unit/data/_async/__init__.py new file mode 100644 index 000000000..6d5e14bcf --- /dev/null +++ b/tests/unit/data/_async/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/data/_testing.py b/tests/unit/data/_testing.py new file mode 100644 index 000000000..b5dd3f444 --- /dev/null +++ b/tests/unit/data/_testing.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa +from unittest.mock import Mock +from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes diff --git a/tests/unit/data/execute_query/__init__.py b/tests/unit/data/execute_query/__init__.py new file mode 100644 index 000000000..6d5e14bcf --- /dev/null +++ b/tests/unit/data/execute_query/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/data/execute_query/_async/__init__.py b/tests/unit/data/execute_query/_async/__init__.py new file mode 100644 index 000000000..6d5e14bcf --- /dev/null +++ b/tests/unit/data/execute_query/_async/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/data/execute_query/_async/_testing.py b/tests/unit/data/execute_query/_async/_testing.py new file mode 100644 index 000000000..5a7acbdd9 --- /dev/null +++ b/tests/unit/data/execute_query/_async/_testing.py @@ -0,0 +1,36 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
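+
+# Re-exports the shared execute query test helpers and defines async_mock, a
+# factory for awaitable mocks: it uses an asyncio.coroutine based shim where that
+# helper still exists and unittest.mock.AsyncMock otherwise.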
+ +# flake8: noqa +from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes + + +try: + # async mock for python3.7-10 + from unittest.mock import Mock + from asyncio import coroutine + + def async_mock(return_value=None): + coro = Mock(name="CoroutineResult") + corofunc = Mock(name="CoroutineFunction", side_effect=coroutine(coro)) + corofunc.coro = coro + corofunc.coro.return_value = return_value + return corofunc + +except ImportError: + # async mock for python3.11 or later + from unittest.mock import AsyncMock + + def async_mock(return_value=None): + return AsyncMock(return_value=return_value) diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py new file mode 100644 index 000000000..5c577ed74 --- /dev/null +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from unittest.mock import Mock +from mock import patch +import pytest +from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( + ExecuteQueryIteratorAsync, +) +from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse +from ._testing import TYPE_INT, proto_rows_bytes, split_bytes_into_chunks, async_mock + + +class MockIteratorAsync: + def __init__(self, values, delay=None): + self._values = values + self.idx = 0 + self._delay = delay + + def __aiter__(self): + return self + + async def __anext__(self): + if self.idx >= len(self._values): + raise StopAsyncIteration + if self._delay is not None: + await asyncio.sleep(self._delay) + value = self._values[self.idx] + self.idx += 1 + return value + + +@pytest.fixture +def proto_byte_stream(): + proto_rows = [ + proto_rows_bytes({"int_value": 1}, {"int_value": 2}), + proto_rows_bytes({"int_value": 3}, {"int_value": 4}), + proto_rows_bytes({"int_value": 5}, {"int_value": 6}), + ] + + messages = [ + *split_bytes_into_chunks(proto_rows[0], num_chunks=2), + *split_bytes_into_chunks(proto_rows[1], num_chunks=3), + proto_rows[2], + ] + + stream = [ + ExecuteQueryResponse( + metadata={ + "proto_schema": { + "columns": [ + {"name": "test1", "type_": TYPE_INT}, + {"name": "test2", "type_": TYPE_INT}, + ] + } + } + ), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[0]}}), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[1]}, + "resume_token": b"token1", + } + ), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[2]}}), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[3]}}), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[4]}, + "resume_token": b"token2", + } + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[5]}, + "resume_token": b"token3", + } + ), + ] + return stream + + +@pytest.mark.asyncio +async def 
test_iterator(proto_byte_stream): + client_mock = Mock() + + client_mock._register_instance = async_mock() + client_mock._remove_instance_registration = async_mock() + mock_async_iterator = MockIteratorAsync(proto_byte_stream) + iterator = None + + with patch( + "google.api_core.retry.retry_target_stream_async", + return_value=mock_async_iterator, + ): + iterator = ExecuteQueryIteratorAsync( + client=client_mock, + instance_id="test-instance", + app_profile_id="test_profile", + request_body={}, + attempt_timeout=10, + operation_timeout=10, + req_metadata=(), + retryable_excs=[], + ) + result = [] + async for value in iterator: + result.append(tuple(value)) + assert result == [(1, 2), (3, 4), (5, 6)] + + assert iterator.is_closed + client_mock._register_instance.assert_called_once() + client_mock._remove_instance_registration.assert_called_once() + + assert mock_async_iterator.idx == len(proto_byte_stream) + + +@pytest.mark.asyncio +async def test_iterator_awaits_metadata(proto_byte_stream): + client_mock = Mock() + + client_mock._register_instance = async_mock() + client_mock._remove_instance_registration = async_mock() + mock_async_iterator = MockIteratorAsync(proto_byte_stream) + iterator = None + with patch( + "google.api_core.retry.retry_target_stream_async", + return_value=mock_async_iterator, + ): + iterator = ExecuteQueryIteratorAsync( + client=client_mock, + instance_id="test-instance", + app_profile_id="test_profile", + request_body={}, + attempt_timeout=10, + operation_timeout=10, + req_metadata=(), + retryable_excs=[], + ) + + await iterator.metadata() + + assert mock_async_iterator.idx == 1 diff --git a/tests/unit/data/execute_query/_testing.py b/tests/unit/data/execute_query/_testing.py new file mode 100644 index 000000000..9d24eee34 --- /dev/null +++ b/tests/unit/data/execute_query/_testing.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa +from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes diff --git a/tests/unit/data/execute_query/test_byte_cursor.py b/tests/unit/data/execute_query/test_byte_cursor.py new file mode 100644 index 000000000..e283e1ca2 --- /dev/null +++ b/tests/unit/data/execute_query/test_byte_cursor.py @@ -0,0 +1,149 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
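+
+# These tests drive _ByteCursor with ExecuteQueryResponse messages and check that
+# batch_data bytes are buffered until a response carries a resume_token, at which
+# point the accumulated bytes are yielded, the token is remembered, and metadata
+# from the first response is exposed.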
+ +from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse +from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor + +from ._testing import TYPE_INT + + +def pass_values_to_byte_cursor(byte_cursor, iterable): + for value in iterable: + result = byte_cursor.consume(value) + if result is not None: + yield result + + +class TestByteCursor: + def test__proto_rows_batch__complete_data(self): + byte_cursor = _ByteCursor() + stream = [ + ExecuteQueryResponse( + metadata={ + "proto_schema": {"columns": [{"name": "test1", "type_": TYPE_INT}]} + } + ), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": b"123"}}), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": b"456"}}), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": b"789"}}), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": b"0"}, + "resume_token": b"token1", + } + ), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": b"abc"}}), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": b"def"}}), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": b"ghi"}}), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": b"j"}, + "resume_token": b"token2", + } + ), + ] + assert byte_cursor.metadata is None + byte_cursor_iter = pass_values_to_byte_cursor(byte_cursor, stream) + value = next(byte_cursor_iter) + assert value == b"1234567890" + assert byte_cursor._resume_token == b"token1" + assert byte_cursor.metadata.columns[0].column_name == "test1" + + value = next(byte_cursor_iter) + assert value == b"abcdefghij" + assert byte_cursor._resume_token == b"token2" + + def test__proto_rows_batch__empty_proto_rows_batch(self): + byte_cursor = _ByteCursor() + stream = [ + ExecuteQueryResponse( + metadata={ + "proto_schema": {"columns": [{"name": "test1", "type_": TYPE_INT}]} + } + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {}, "resume_token": b"token1"} + ), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": b"123"}}), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": b"0"}, + "resume_token": b"token2", + } + ), + ] + + byte_cursor_iter = pass_values_to_byte_cursor(byte_cursor, stream) + value = next(byte_cursor_iter) + assert value == b"1230" + assert byte_cursor._resume_token == b"token2" + + def test__proto_rows_batch__no_proto_rows_batch(self): + byte_cursor = _ByteCursor() + stream = [ + ExecuteQueryResponse( + metadata={ + "proto_schema": {"columns": [{"name": "test1", "type_": TYPE_INT}]} + } + ), + ExecuteQueryResponse(results={"resume_token": b"token1"}), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": b"123"}}), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": b"0"}, + "resume_token": b"token2", + } + ), + ] + + byte_cursor_iter = pass_values_to_byte_cursor(byte_cursor, stream) + value = next(byte_cursor_iter) + assert value == b"1230" + assert byte_cursor._resume_token == b"token2" + + def test__proto_rows_batch__no_resume_token_at_the_end_of_stream(self): + byte_cursor = _ByteCursor() + stream = [ + ExecuteQueryResponse( + metadata={ + "proto_schema": {"columns": [{"name": "test1", "type_": TYPE_INT}]} + } + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": b"0"}, + "resume_token": b"token1", + } + ), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": b"abc"}}), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": b"def"}}), + 
ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": b"ghi"}}), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": b"j"}, + } + ), + ] + assert byte_cursor.metadata is None + assert byte_cursor.consume(stream[0]) is None + value = byte_cursor.consume(stream[1]) + assert value == b"0" + assert byte_cursor._resume_token == b"token1" + assert byte_cursor.metadata.columns[0].column_name == "test1" + + assert byte_cursor.consume(stream[2]) is None + assert byte_cursor.consume(stream[3]) is None + assert byte_cursor.consume(stream[3]) is None + assert byte_cursor.consume(stream[4]) is None + assert byte_cursor.consume(stream[5]) is None diff --git a/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py b/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py new file mode 100644 index 000000000..914a0920a --- /dev/null +++ b/tests/unit/data/execute_query/test_execute_query_parameters_parsing.py @@ -0,0 +1,134 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from google.cloud.bigtable.data.execute_query._parameters_formatting import ( + _format_execute_query_params, +) +from google.cloud.bigtable.data.execute_query.metadata import SqlType +from google.cloud.bigtable.data.execute_query.values import Struct +import datetime + +from google.type import date_pb2 +from google.api_core.datetime_helpers import DatetimeWithNanoseconds + + +timestamp = int( + datetime.datetime(2024, 5, 12, 17, 44, 12, tzinfo=datetime.timezone.utc).timestamp() +) +dt_micros_non_zero = DatetimeWithNanoseconds( + 2024, 5, 12, 17, 44, 12, 123, nanosecond=0, tzinfo=datetime.timezone.utc +).timestamp_pb() +dt_nanos_zero = DatetimeWithNanoseconds( + 2024, 5, 12, 17, 44, 12, nanosecond=0, tzinfo=datetime.timezone.utc +).timestamp_pb() +dt_nanos_non_zero = DatetimeWithNanoseconds( + 2024, 5, 12, 17, 44, 12, nanosecond=12, tzinfo=datetime.timezone.utc +).timestamp_pb() +pb_date = date_pb2.Date(year=2024, month=5, day=15) + + +@pytest.mark.parametrize( + "input_value,value_field,type_field,expected_value", + [ + (1, "int_value", "int64_type", 1), + ("2", "string_value", "string_type", "2"), + (b"3", "bytes_value", "bytes_type", b"3"), + (True, "bool_value", "bool_type", True), + ( + datetime.datetime.fromtimestamp(timestamp), + "timestamp_value", + "timestamp_type", + dt_nanos_zero, + ), + ( + datetime.datetime( + 2024, 5, 12, 17, 44, 12, 123, tzinfo=datetime.timezone.utc + ), + "timestamp_value", + "timestamp_type", + dt_micros_non_zero, + ), + (datetime.date(2024, 5, 15), "date_value", "date_type", pb_date), + ( + DatetimeWithNanoseconds( + 2024, 5, 12, 17, 44, 12, nanosecond=12, tzinfo=datetime.timezone.utc + ), + "timestamp_value", + "timestamp_type", + dt_nanos_non_zero, + ), + ], +) +def test_instance_execute_query_parameters_simple_types_parsing( + input_value, value_field, type_field, expected_value +): + result = _format_execute_query_params( + { + "test": input_value, + }, + None, + ) + assert 
result["test"][value_field] == expected_value + assert type_field in result["test"]["type_"] + + +def test_instance_execute_query_parameters_not_supported_types(): + with pytest.raises(ValueError): + _format_execute_query_params({"test1": 1.1}, None) + + with pytest.raises(ValueError): + _format_execute_query_params({"test1": {"a": 1}}, None) + + with pytest.raises(ValueError): + _format_execute_query_params({"test1": [1]}, None) + + with pytest.raises(ValueError): + _format_execute_query_params({"test1": Struct([("field1", 1)])}, None) + + with pytest.raises(NotImplementedError, match="not supported"): + _format_execute_query_params( + {"test1": {"a": 1}}, + { + "test1": SqlType.Map(SqlType.String(), SqlType.Int64()), + }, + ) + + with pytest.raises(NotImplementedError, match="not supported"): + _format_execute_query_params( + {"test1": [1]}, + { + "test1": SqlType.Array(SqlType.Int64()), + }, + ) + + with pytest.raises(NotImplementedError, match="not supported"): + _format_execute_query_params( + {"test1": Struct([("field1", 1)])}, + {"test1": SqlType.Struct([("field1", SqlType.Int64())])}, + ) + + +def test_instance_execute_query_parameters_not_match(): + with pytest.raises(ValueError, match="test2"): + _format_execute_query_params( + { + "test1": 1, + "test2": 1, + }, + { + "test1": SqlType.Int64(), + "test2": SqlType.String(), + }, + ) diff --git a/tests/unit/data/execute_query/test_query_result_parsing_utils.py b/tests/unit/data/execute_query/test_query_result_parsing_utils.py new file mode 100644 index 000000000..ff7211654 --- /dev/null +++ b/tests/unit/data/execute_query/test_query_result_parsing_utils.py @@ -0,0 +1,715 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from google.cloud.bigtable.data.execute_query.values import Struct +from google.cloud.bigtable_v2 import Type as PBType, Value as PBValue +from google.cloud.bigtable.data.execute_query._query_result_parsing_utils import ( + _parse_pb_value_to_python_value, +) +from google.cloud.bigtable.data.execute_query.metadata import ( + _pb_type_to_metadata_type, + SqlType, +) + +from google.type import date_pb2 +from google.api_core.datetime_helpers import DatetimeWithNanoseconds + +import datetime + +from ._testing import TYPE_INT + +TYPE_BYTES = {"bytes_type": {}} +TYPE_TIMESTAMP = {"timestamp_type": {}} + + +class TestQueryResultParsingUtils: + @pytest.mark.parametrize( + "type_dict,value_dict,expected_metadata_type,expected_value", + [ + (TYPE_INT, {"int_value": 1}, SqlType.Int64, 1), + ( + {"string_type": {}}, + {"string_value": "test"}, + SqlType.String, + "test", + ), + ({"bool_type": {}}, {"bool_value": False}, SqlType.Bool, False), + ( + {"bytes_type": {}}, + {"bytes_value": b"test"}, + SqlType.Bytes, + b"test", + ), + ( + {"float64_type": {}}, + {"float_value": 17.21}, + SqlType.Float64, + 17.21, + ), + ( + {"timestamp_type": {}}, + {"timestamp_value": {"seconds": 1715864647, "nanos": 12}}, + SqlType.Timestamp, + DatetimeWithNanoseconds( + 2024, 5, 16, 13, 4, 7, nanosecond=12, tzinfo=datetime.timezone.utc + ), + ), + ( + {"date_type": {}}, + {"date_value": {"year": 1800, "month": 12, "day": 0}}, + SqlType.Date, + date_pb2.Date(year=1800, month=12, day=0), + ), + ], + ) + def test_basic_types( + self, type_dict, value_dict, expected_metadata_type, expected_value + ): + _type = PBType(type_dict) + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is expected_metadata_type + value = PBValue(value_dict) + assert ( + _parse_pb_value_to_python_value(value._pb, metadata_type) == expected_value + ) + + # Larger test cases were extracted for readability + def test__array(self): + _type = PBType({"array_type": {"element_type": TYPE_INT}}) + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Array + assert type(metadata_type.element_type) is SqlType.Int64 + value = PBValue( + { + "array_value": { + "values": [ + {"int_value": 1}, + {"int_value": 2}, + {"int_value": 3}, + {"int_value": 4}, + ] + } + } + ) + assert _parse_pb_value_to_python_value(value._pb, metadata_type) == [1, 2, 3, 4] + + def test__struct(self): + _type = PBType( + { + "struct_type": { + "fields": [ + { + "field_name": "field1", + "type_": TYPE_INT, + }, + { + "field_name": None, + "type_": {"string_type": {}}, + }, + { + "field_name": "field3", + "type_": {"array_type": {"element_type": TYPE_INT}}, + }, + { + "field_name": "field3", + "type_": {"string_type": {}}, + }, + ] + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ + {"int_value": 1}, + {"string_value": "test2"}, + { + "array_value": { + "values": [ + {"int_value": 2}, + {"int_value": 3}, + {"int_value": 4}, + {"int_value": 5}, + ] + } + }, + {"string_value": "test4"}, + ] + } + } + ) + + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Struct + assert type(metadata_type["field1"]) is SqlType.Int64 + assert type(metadata_type[1]) is SqlType.String + assert type(metadata_type[2]) is SqlType.Array + assert type(metadata_type[2].element_type) is SqlType.Int64 + assert type(metadata_type[3]) is SqlType.String + + # duplicate fields not accesible by name + with pytest.raises(KeyError, match="Ambigious field name"): + metadata_type["field3"] 
+ + result = _parse_pb_value_to_python_value(value._pb, metadata_type) + assert isinstance(result, Struct) + assert result["field1"] == result[0] == 1 + assert result[1] == "test2" + + # duplicate fields not accesible by name + with pytest.raises(KeyError, match="Ambigious field name"): + result["field3"] + + # duplicate fields accessible by index + assert result[2] == [2, 3, 4, 5] + assert result[3] == "test4" + + def test__array_of_structs(self): + _type = PBType( + { + "array_type": { + "element_type": { + "struct_type": { + "fields": [ + { + "field_name": "field1", + "type_": TYPE_INT, + }, + { + "field_name": None, + "type_": {"string_type": {}}, + }, + { + "field_name": "field3", + "type_": {"bool_type": {}}, + }, + ] + } + } + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ + { + "array_value": { + "values": [ + {"int_value": 1}, + {"string_value": "test1"}, + {"bool_value": True}, + ] + } + }, + { + "array_value": { + "values": [ + {"int_value": 2}, + {"string_value": "test2"}, + {"bool_value": False}, + ] + } + }, + { + "array_value": { + "values": [ + {"int_value": 3}, + {"string_value": "test3"}, + {"bool_value": True}, + ] + } + }, + { + "array_value": { + "values": [ + {"int_value": 4}, + {"string_value": "test4"}, + {"bool_value": False}, + ] + } + }, + ] + } + } + ) + + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Array + assert type(metadata_type.element_type) is SqlType.Struct + assert type(metadata_type.element_type["field1"]) is SqlType.Int64 + assert type(metadata_type.element_type[1]) is SqlType.String + assert type(metadata_type.element_type["field3"]) is SqlType.Bool + + result = _parse_pb_value_to_python_value(value._pb, metadata_type) + assert isinstance(result, list) + assert len(result) == 4 + + assert isinstance(result[0], Struct) + assert result[0]["field1"] == 1 + assert result[0][1] == "test1" + assert result[0]["field3"] + + assert isinstance(result[1], Struct) + assert result[1]["field1"] == 2 + assert result[1][1] == "test2" + assert not result[1]["field3"] + + assert isinstance(result[2], Struct) + assert result[2]["field1"] == 3 + assert result[2][1] == "test3" + assert result[2]["field3"] + + assert isinstance(result[3], Struct) + assert result[3]["field1"] == 4 + assert result[3][1] == "test4" + assert not result[3]["field3"] + + def test__map(self): + _type = PBType( + { + "map_type": { + "key_type": TYPE_INT, + "value_type": {"string_type": {}}, + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ + { + "array_value": { + "values": [ + {"int_value": 1}, + {"string_value": "test1"}, + ] + } + }, + { + "array_value": { + "values": [ + {"int_value": 2}, + {"string_value": "test2"}, + ] + } + }, + { + "array_value": { + "values": [ + {"int_value": 3}, + {"string_value": "test3"}, + ] + } + }, + { + "array_value": { + "values": [ + {"int_value": 4}, + {"string_value": "test4"}, + ] + } + }, + ] + } + } + ) + + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Map + assert type(metadata_type.key_type) is SqlType.Int64 + assert type(metadata_type.value_type) is SqlType.String + + result = _parse_pb_value_to_python_value(value._pb, metadata_type) + assert isinstance(result, dict) + assert len(result) == 4 + + assert result == { + 1: "test1", + 2: "test2", + 3: "test3", + 4: "test4", + } + + def test__map_repeated_values(self): + _type = PBType( + { + "map_type": { + "key_type": TYPE_INT, + "value_type": {"string_type": {}}, + } + }, + ) + value 
= PBValue( + { + "array_value": { + "values": [ + { + "array_value": { + "values": [ + {"int_value": 1}, + {"string_value": "test1"}, + ] + } + }, + { + "array_value": { + "values": [ + {"int_value": 1}, + {"string_value": "test2"}, + ] + } + }, + { + "array_value": { + "values": [ + {"int_value": 1}, + {"string_value": "test3"}, + ] + } + }, + ] + } + } + ) + + metadata_type = _pb_type_to_metadata_type(_type) + result = _parse_pb_value_to_python_value(value._pb, metadata_type) + assert len(result) == 1 + + assert result == { + 1: "test3", + } + + def test__map_of_maps_of_structs(self): + _type = PBType( + { + "map_type": { + "key_type": TYPE_INT, + "value_type": { + "map_type": { + "key_type": {"string_type": {}}, + "value_type": { + "struct_type": { + "fields": [ + { + "field_name": "field1", + "type_": TYPE_INT, + }, + { + "field_name": "field2", + "type_": {"string_type": {}}, + }, + ] + } + }, + } + }, + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ # list of (int, map) tuples + { + "array_value": { + "values": [ # (int, map) tuple + {"int_value": 1}, + { + "array_value": { + "values": [ # list of (str, struct) tuples + { + "array_value": { + "values": [ # (str, struct) tuple + {"string_value": "1_1"}, + { + "array_value": { + "values": [ + { + "int_value": 1 + }, + { + "string_value": "test1" + }, + ] + } + }, + ] + } + }, + { + "array_value": { + "values": [ # (str, struct) tuple + {"string_value": "1_2"}, + { + "array_value": { + "values": [ + { + "int_value": 2 + }, + { + "string_value": "test2" + }, + ] + } + }, + ] + } + }, + ] + } + }, + ] + } + }, + { + "array_value": { + "values": [ # (int, map) tuple + {"int_value": 2}, + { + "array_value": { + "values": [ # list of (str, struct) tuples + { + "array_value": { + "values": [ # (str, struct) tuple + {"string_value": "2_1"}, + { + "array_value": { + "values": [ + { + "int_value": 3 + }, + { + "string_value": "test3" + }, + ] + } + }, + ] + } + }, + { + "array_value": { + "values": [ # (str, struct) tuple + {"string_value": "2_2"}, + { + "array_value": { + "values": [ + { + "int_value": 4 + }, + { + "string_value": "test4" + }, + ] + } + }, + ] + } + }, + ] + } + }, + ] + } + }, + ] + } + } + ) + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Map + assert type(metadata_type.key_type) is SqlType.Int64 + assert type(metadata_type.value_type) is SqlType.Map + assert type(metadata_type.value_type.key_type) is SqlType.String + assert type(metadata_type.value_type.value_type) is SqlType.Struct + assert type(metadata_type.value_type.value_type["field1"]) is SqlType.Int64 + assert type(metadata_type.value_type.value_type["field2"]) is SqlType.String + result = _parse_pb_value_to_python_value(value._pb, metadata_type) + + assert result[1]["1_1"]["field1"] == 1 + assert result[1]["1_1"]["field2"] == "test1" + + assert result[1]["1_2"]["field1"] == 2 + assert result[1]["1_2"]["field2"] == "test2" + + assert result[2]["2_1"]["field1"] == 3 + assert result[2]["2_1"]["field2"] == "test3" + + assert result[2]["2_2"]["field1"] == 4 + assert result[2]["2_2"]["field2"] == "test4" + + def test__map_of_lists_of_structs(self): + _type = PBType( + { + "map_type": { + "key_type": TYPE_BYTES, + "value_type": { + "array_type": { + "element_type": { + "struct_type": { + "fields": [ + { + "field_name": "timestamp", + "type_": TYPE_TIMESTAMP, + }, + { + "field_name": "value", + "type_": TYPE_BYTES, + }, + ] + } + }, + } + }, + } + } + ) + value = PBValue( + { + "array_value": { + "values": [ # 
list of (byte, list) tuples + { + "array_value": { + "values": [ # (byte, list) tuple + {"bytes_value": b"key1"}, + { + "array_value": { + "values": [ # list of structs + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 1111111111 + } + }, + { + "bytes_value": b"key1-value1" + }, + ] + } + }, + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 2222222222 + } + }, + { + "bytes_value": b"key1-value2" + }, + ] + } + }, + ] + } + }, + ] + } + }, + { + "array_value": { + "values": [ # (byte, list) tuple + {"bytes_value": b"key2"}, + { + "array_value": { + "values": [ # list of structs + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 3333333333 + } + }, + { + "bytes_value": b"key2-value1" + }, + ] + } + }, + { + "array_value": { + "values": [ # (timestamp, bytes) tuple + { + "timestamp_value": { + "seconds": 4444444444 + } + }, + { + "bytes_value": b"key2-value2" + }, + ] + } + }, + ] + } + }, + ] + } + }, + ] + } + } + ) + metadata_type = _pb_type_to_metadata_type(_type) + assert type(metadata_type) is SqlType.Map + assert type(metadata_type.key_type) is SqlType.Bytes + assert type(metadata_type.value_type) is SqlType.Array + assert type(metadata_type.value_type.element_type) is SqlType.Struct + assert ( + type(metadata_type.value_type.element_type["timestamp"]) + is SqlType.Timestamp + ) + assert type(metadata_type.value_type.element_type["value"]) is SqlType.Bytes + result = _parse_pb_value_to_python_value(value._pb, metadata_type) + + timestamp1 = DatetimeWithNanoseconds( + 2005, 3, 18, 1, 58, 31, tzinfo=datetime.timezone.utc + ) + timestamp2 = DatetimeWithNanoseconds( + 2040, 6, 2, 3, 57, 2, tzinfo=datetime.timezone.utc + ) + timestamp3 = DatetimeWithNanoseconds( + 2075, 8, 18, 5, 55, 33, tzinfo=datetime.timezone.utc + ) + timestamp4 = DatetimeWithNanoseconds( + 2110, 11, 3, 7, 54, 4, tzinfo=datetime.timezone.utc + ) + + assert result[b"key1"][0]["timestamp"] == timestamp1 + assert result[b"key1"][0]["value"] == b"key1-value1" + assert result[b"key1"][1]["timestamp"] == timestamp2 + assert result[b"key1"][1]["value"] == b"key1-value2" + assert result[b"key2"][0]["timestamp"] == timestamp3 + assert result[b"key2"][0]["value"] == b"key2-value1" + assert result[b"key2"][1]["timestamp"] == timestamp4 + assert result[b"key2"][1]["value"] == b"key2-value2" + + def test__invalid_type_throws_exception(self): + _type = PBType({"string_type": {}}) + value = PBValue({"int_value": 1}) + metadata_type = _pb_type_to_metadata_type(_type) + + with pytest.raises( + ValueError, + match="string_value field for String type not found in a Value.", + ): + _parse_pb_value_to_python_value(value._pb, metadata_type) diff --git a/tests/unit/data/execute_query/test_query_result_row_reader.py b/tests/unit/data/execute_query/test_query_result_row_reader.py new file mode 100644 index 000000000..2bb1e4da0 --- /dev/null +++ b/tests/unit/data/execute_query/test_query_result_row_reader.py @@ -0,0 +1,310 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest import mock +from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse +from google.cloud.bigtable_v2.types.data import Value as PBValue +from google.cloud.bigtable.data.execute_query._reader import _QueryResultRowReader + +from google.cloud.bigtable.data.execute_query.metadata import ProtoMetadata, SqlType + +import google.cloud.bigtable.data.execute_query._reader +from ._testing import TYPE_INT, proto_rows_bytes + + +class TestQueryResultRowReader: + def test__single_values_received(self): + byte_cursor = mock.Mock( + metadata=ProtoMetadata( + [("test1", SqlType.Int64()), ("test2", SqlType.Int64())] + ) + ) + values = [ + proto_rows_bytes({"int_value": 1}), + proto_rows_bytes({"int_value": 2}), + proto_rows_bytes({"int_value": 3}), + ] + + reader = _QueryResultRowReader(byte_cursor) + + assert reader.consume(values[0]) is None + result = reader.consume(values[1]) + assert len(result) == 1 + assert len(result[0]) == 2 + assert reader.consume(values[2]) is None + + def test__multiple_rows_received(self): + values = [ + proto_rows_bytes( + {"int_value": 1}, + {"int_value": 2}, + {"int_value": 3}, + {"int_value": 4}, + ), + proto_rows_bytes({"int_value": 5}, {"int_value": 6}), + proto_rows_bytes({"int_value": 7}, {"int_value": 8}), + ] + + byte_cursor = mock.Mock( + metadata=ProtoMetadata( + [("test1", SqlType.Int64()), ("test2", SqlType.Int64())] + ) + ) + + reader = _QueryResultRowReader(byte_cursor) + + result = reader.consume(values[0]) + assert len(result) == 2 + assert len(result[0]) == 2 + assert result[0][0] == result[0]["test1"] == 1 + assert result[0][1] == result[0]["test2"] == 2 + + assert len(result[1]) == 2 + assert result[1][0] == result[1]["test1"] == 3 + assert result[1][1] == result[1]["test2"] == 4 + + result = reader.consume(values[1]) + assert len(result) == 1 + assert len(result[0]) == 2 + assert result[0][0] == result[0]["test1"] == 5 + assert result[0][1] == result[0]["test2"] == 6 + + result = reader.consume(values[2]) + assert len(result) == 1 + assert len(result[0]) == 2 + assert result[0][0] == result[0]["test1"] == 7 + assert result[0][1] == result[0]["test2"] == 8 + + def test__received_values_are_passed_to_parser_in_batches(self): + byte_cursor = mock.Mock( + metadata=ProtoMetadata( + [("test1", SqlType.Int64()), ("test2", SqlType.Int64())] + ) + ) + + assert SqlType.Struct([("a", SqlType.Int64())]) == SqlType.Struct( + [("a", SqlType.Int64())] + ) + assert SqlType.Struct([("a", SqlType.String())]) != SqlType.Struct( + [("a", SqlType.Int64())] + ) + assert SqlType.Struct([("a", SqlType.Int64())]) != SqlType.Struct( + [("b", SqlType.Int64())] + ) + + assert SqlType.Array(SqlType.Int64()) == SqlType.Array(SqlType.Int64()) + assert SqlType.Array(SqlType.Int64()) != SqlType.Array(SqlType.String()) + + assert SqlType.Map(SqlType.Int64(), SqlType.String()) == SqlType.Map( + SqlType.Int64(), SqlType.String() + ) + assert SqlType.Map(SqlType.Int64(), SqlType.String()) != SqlType.Map( + SqlType.String(), SqlType.String() + ) + + values = [ + {"int_value": 1}, + {"int_value": 2}, + ] + + reader = 
_QueryResultRowReader(byte_cursor) + with mock.patch.object( + google.cloud.bigtable.data.execute_query._reader, + "_parse_pb_value_to_python_value", + ) as parse_mock: + reader.consume(proto_rows_bytes(values[0])) + parse_mock.assert_not_called() + reader.consume(proto_rows_bytes(values[1])) + parse_mock.assert_has_calls( + [ + mock.call(PBValue(values[0]), SqlType.Int64()), + mock.call(PBValue(values[1]), SqlType.Int64()), + ] + ) + + def test__parser_errors_are_forwarded(self): + byte_cursor = mock.Mock(metadata=ProtoMetadata([("test1", SqlType.Int64())])) + + values = [ + {"string_value": "test"}, + ] + + reader = _QueryResultRowReader(byte_cursor) + with mock.patch.object( + google.cloud.bigtable.data.execute_query._reader, + "_parse_pb_value_to_python_value", + side_effect=ValueError("test"), + ) as parse_mock: + with pytest.raises(ValueError, match="test"): + reader.consume(proto_rows_bytes(values[0])) + + parse_mock.assert_has_calls( + [ + mock.call(PBValue(values[0]), SqlType.Int64()), + ] + ) + + def test__multiple_proto_rows_received_with_one_resume_token(self): + from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor + + def split_bytes_into_chunks(bytes_to_split, num_chunks): + from google.cloud.bigtable.helpers import batched + + assert num_chunks <= len(bytes_to_split) + bytes_per_part = (len(bytes_to_split) - 1) // num_chunks + 1 + result = list(map(bytes, batched(bytes_to_split, bytes_per_part))) + assert len(result) == num_chunks + return result + + def pass_values_to_byte_cursor(byte_cursor, iterable): + for value in iterable: + result = byte_cursor.consume(value) + if result is not None: + yield result + + proto_rows = [ + proto_rows_bytes({"int_value": 1}, {"int_value": 2}), + proto_rows_bytes({"int_value": 3}, {"int_value": 4}), + proto_rows_bytes({"int_value": 5}, {"int_value": 6}), + ] + + messages = [ + *split_bytes_into_chunks(proto_rows[0], num_chunks=2), + *split_bytes_into_chunks(proto_rows[1], num_chunks=3), + proto_rows[2], + ] + + stream = [ + ExecuteQueryResponse( + metadata={ + "proto_schema": { + "columns": [ + {"name": "test1", "type_": TYPE_INT}, + {"name": "test2", "type_": TYPE_INT}, + ] + } + } + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[0]}} + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[1]}} + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[2]}} + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[3]}} + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[4]}, + "resume_token": b"token1", + } + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[5]}, + "resume_token": b"token2", + } + ), + ] + + byte_cursor = _ByteCursor() + + reader = _QueryResultRowReader(byte_cursor) + + byte_cursor_iter = pass_values_to_byte_cursor(byte_cursor, stream) + + returned_values = [] + + def intercept_return_values(func): + nonlocal intercept_return_values + + def wrapped(*args, **kwargs): + value = func(*args, **kwargs) + returned_values.append(value) + return value + + return wrapped + + with mock.patch.object( + reader, + "_parse_proto_rows", + wraps=intercept_return_values(reader._parse_proto_rows), + ): + result = reader.consume(next(byte_cursor_iter)) + + # Despite the fact that two ProtoRows were received, a single resume_token after the second ProtoRows object forces us to parse them together. 
+ # We will interpret them as one larger ProtoRows object. + assert len(returned_values) == 1 + assert len(returned_values[0]) == 4 + assert returned_values[0][0].int_value == 1 + assert returned_values[0][1].int_value == 2 + assert returned_values[0][2].int_value == 3 + assert returned_values[0][3].int_value == 4 + + assert len(result) == 2 + assert len(result[0]) == 2 + assert result[0][0] == 1 + assert result[0]["test1"] == 1 + assert result[0][1] == 2 + assert result[0]["test2"] == 2 + assert len(result[1]) == 2 + assert result[1][0] == 3 + assert result[1]["test1"] == 3 + assert result[1][1] == 4 + assert result[1]["test2"] == 4 + assert byte_cursor._resume_token == b"token1" + + returned_values = [] + with mock.patch.object( + reader, + "_parse_proto_rows", + wraps=intercept_return_values(reader._parse_proto_rows), + ): + result = reader.consume(next(byte_cursor_iter)) + + assert len(result) == 1 + assert len(result[0]) == 2 + assert result[0][0] == 5 + assert result[0]["test1"] == 5 + assert result[0][1] == 6 + assert result[0]["test2"] == 6 + assert byte_cursor._resume_token == b"token2" + + +class TestProtoMetadata: + def test__duplicate_column_names(self): + metadata = ProtoMetadata( + [ + ("test1", SqlType.Int64()), + ("test2", SqlType.Bytes()), + ("test2", SqlType.String()), + ] + ) + assert metadata[0].column_name == "test1" + assert metadata["test1"].column_type == SqlType.Int64() + + # duplicate columns not accesible by name + with pytest.raises(KeyError, match="Ambigious column name"): + metadata["test2"] + + # duplicate columns accessible by index + assert metadata[1].column_type == SqlType.Bytes() + assert metadata[1].column_name == "test2" + assert metadata[2].column_type == SqlType.String() + assert metadata[2].column_name == "test2" diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 5a9c500ed..12ab3181e 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -23,16 +23,31 @@ class TestMakeMetadata: @pytest.mark.parametrize( - "table,profile,expected", + "table,profile,instance,expected", [ - ("table", "profile", "table_name=table&app_profile_id=profile"), - ("table", None, "table_name=table"), + ("table", "profile", None, "table_name=table&app_profile_id=profile"), + ("table", None, None, "table_name=table"), + (None, None, "instance", "name=instance"), + (None, "profile", None, "app_profile_id=profile"), + (None, "profile", "instance", "name=instance&app_profile_id=profile"), ], ) - def test__make_metadata(self, table, profile, expected): - metadata = _helpers._make_metadata(table, profile) + def test__make_metadata(self, table, profile, instance, expected): + metadata = _helpers._make_metadata(table, profile, instance) assert metadata == [("x-goog-request-params", expected)] + @pytest.mark.parametrize( + "table,profile,instance", + [ + ("table", None, "instance"), + ("table", "profile", "instance"), + (None, None, None), + ], + ) + def test__make_metadata_invalid_params(self, table, profile, instance): + with pytest.raises(ValueError): + _helpers._make_metadata(table, profile, instance) + class TestAttemptTimeoutGenerator: @pytest.mark.parametrize( diff --git a/tests/unit/data/test_helpers.py b/tests/unit/data/test_helpers.py new file mode 100644 index 000000000..5d1ad70f8 --- /dev/null +++ b/tests/unit/data/test_helpers.py @@ -0,0 +1,45 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from google.cloud.bigtable.helpers import batched + + +class TestBatched: + @pytest.mark.parametrize( + "input_list,batch_size,expected", + [ + ([1, 2, 3, 4, 5], 3, [[1, 2, 3], [4, 5]]), + ([1, 2, 3, 4, 5, 6], 3, [[1, 2, 3], [4, 5, 6]]), + ([1, 2, 3, 4, 5], 2, [[1, 2], [3, 4], [5]]), + ([1, 2, 3, 4, 5], 1, [[1], [2], [3], [4], [5]]), + ([1, 2, 3, 4, 5], 5, [[1, 2, 3, 4, 5]]), + ([], 1, []), + ], + ) + def test_batched(self, input_list, batch_size, expected): + result = list(batched(input_list, batch_size)) + assert list(map(list, result)) == expected + + @pytest.mark.parametrize( + "input_list,batch_size", + [ + ([1], 0), + ([1], -1), + ], + ) + def test_batched_errs(self, input_list, batch_size): + with pytest.raises(ValueError): + list(batched(input_list, batch_size)) diff --git a/tests/unit/v2_client/_testing.py b/tests/unit/v2_client/_testing.py index 302d33ac1..855c0c10e 100644 --- a/tests/unit/v2_client/_testing.py +++ b/tests/unit/v2_client/_testing.py @@ -17,6 +17,9 @@ import mock +# flake8: noqa +from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes + class _FakeStub(object): """Acts as a gPRC stub.""" diff --git a/tests/unit/v2_client/test_instance.py b/tests/unit/v2_client/test_instance.py index 797e4bd9c..de6844a16 100644 --- a/tests/unit/v2_client/test_instance.py +++ b/tests/unit/v2_client/test_instance.py @@ -19,6 +19,7 @@ from ._testing import _make_credentials from google.cloud.bigtable.cluster import Cluster + PROJECT = "project" INSTANCE_ID = "instance-id" INSTANCE_NAME = "projects/" + PROJECT + "/instances/" + INSTANCE_ID @@ -943,3 +944,28 @@ def _next_page(self): assert isinstance(app_profile_2, AppProfile) assert app_profile_2.name == app_profile_name2 + + +@pytest.fixture() +def data_api(): + from google.cloud.bigtable_v2.services.bigtable import BigtableClient + + data_api_mock = mock.create_autospec(BigtableClient) + data_api_mock.instance_path.return_value = ( + f"projects/{PROJECT}/instances/{INSTANCE_ID}" + ) + return data_api_mock + + +@pytest.fixture() +def client(data_api): + result = _make_client( + project="project-id", credentials=_make_credentials(), admin=True + ) + result._table_data_client = data_api + return result + + +@pytest.fixture() +def instance(client): + return client.instance(instance_id=INSTANCE_ID)
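
For context on the behaviour pinned down by TestBatched in tests/unit/data/test_helpers.py above (chunks of at most n items, a shorter final chunk, ValueError for n < 1), a helper along the lines of the itertools.batched recipe would satisfy those tests. This is only an illustrative sketch; the actual implementation added in google/cloud/bigtable/helpers.py by this patch may differ.

    import itertools

    def batched(iterable, n):
        # Yield successive tuples of at most n items, mirroring
        # itertools.batched from Python 3.12. Sketch only, not the
        # patch's actual code.
        if n < 1:
            raise ValueError("n must be at least one")
        it = iter(iterable)
        while batch := tuple(itertools.islice(it, n)):
            yield batch

Because the helper is a generator, the ValueError for an invalid n surfaces when the result is iterated, which is why the error-case tests above wrap the call in list(...).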