From 5d81dcc53565e8bca901f82b104f3acb95361811 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 1 Nov 2024 11:10:00 -0700 Subject: [PATCH 01/14] remove pooled transport from gapic folder --- .../bigtable_v2/services/bigtable/client.py | 2 - .../services/bigtable/transports/__init__.py | 3 - .../transports/pooled_grpc_asyncio.py | 426 ------------------ owlbot.py | 46 -- 4 files changed, 477 deletions(-) delete mode 100644 google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py diff --git a/google/cloud/bigtable_v2/services/bigtable/client.py b/google/cloud/bigtable_v2/services/bigtable/client.py index 86fa6b3a5..92e0c0011 100644 --- a/google/cloud/bigtable_v2/services/bigtable/client.py +++ b/google/cloud/bigtable_v2/services/bigtable/client.py @@ -55,7 +55,6 @@ from .transports.base import BigtableTransport, DEFAULT_CLIENT_INFO from .transports.grpc import BigtableGrpcTransport from .transports.grpc_asyncio import BigtableGrpcAsyncIOTransport -from .transports.pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport from .transports.rest import BigtableRestTransport @@ -70,7 +69,6 @@ class BigtableClientMeta(type): _transport_registry = OrderedDict() # type: Dict[str, Type[BigtableTransport]] _transport_registry["grpc"] = BigtableGrpcTransport _transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport - _transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport _transport_registry["rest"] = BigtableRestTransport def get_transport_class( diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py b/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py index ae5c1cf72..ae007bc2b 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/__init__.py @@ -19,7 +19,6 @@ from .base import BigtableTransport from .grpc import BigtableGrpcTransport from .grpc_asyncio import BigtableGrpcAsyncIOTransport 
-from .pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport from .rest import BigtableRestTransport from .rest import BigtableRestInterceptor @@ -28,14 +27,12 @@ _transport_registry = OrderedDict() # type: Dict[str, Type[BigtableTransport]] _transport_registry["grpc"] = BigtableGrpcTransport _transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport -_transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport _transport_registry["rest"] = BigtableRestTransport __all__ = ( "BigtableTransport", "BigtableGrpcTransport", "BigtableGrpcAsyncIOTransport", - "PooledBigtableGrpcAsyncIOTransport", "BigtableRestTransport", "BigtableRestInterceptor", ) diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py deleted file mode 100644 index 372e5796d..000000000 --- a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py +++ /dev/null @@ -1,426 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import asyncio -import warnings -from functools import partialmethod -from functools import partial -from typing import ( - Awaitable, - Callable, - Dict, - Optional, - Sequence, - Tuple, - Union, - List, - Type, -) - -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers_async -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.cloud.bigtable_v2.types import bigtable -from .base import BigtableTransport, DEFAULT_CLIENT_INFO -from .grpc_asyncio import BigtableGrpcAsyncIOTransport - - -class PooledMultiCallable: - def __init__(self, channel_pool: "PooledChannel", *args, **kwargs): - self._init_args = args - self._init_kwargs = kwargs - self.next_channel_fn = channel_pool.next_channel - - -class PooledUnaryUnaryMultiCallable(PooledMultiCallable, aio.UnaryUnaryMultiCallable): - def __call__(self, *args, **kwargs) -> aio.UnaryUnaryCall: - return self.next_channel_fn().unary_unary( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledUnaryStreamMultiCallable(PooledMultiCallable, aio.UnaryStreamMultiCallable): - def __call__(self, *args, **kwargs) -> aio.UnaryStreamCall: - return self.next_channel_fn().unary_stream( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledStreamUnaryMultiCallable(PooledMultiCallable, aio.StreamUnaryMultiCallable): - def __call__(self, *args, **kwargs) -> aio.StreamUnaryCall: - return self.next_channel_fn().stream_unary( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledStreamStreamMultiCallable( - PooledMultiCallable, aio.StreamStreamMultiCallable -): - def __call__(self, *args, **kwargs) -> aio.StreamStreamCall: - return self.next_channel_fn().stream_stream( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledChannel(aio.Channel): 
- def __init__( - self, - pool_size: int = 3, - host: str = "bigtable.googleapis.com", - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - quota_project_id: Optional[str] = None, - default_scopes: Optional[Sequence[str]] = None, - scopes: Optional[Sequence[str]] = None, - default_host: Optional[str] = None, - insecure: bool = False, - **kwargs, - ): - self._pool: List[aio.Channel] = [] - self._next_idx = 0 - if insecure: - self._create_channel = partial(aio.insecure_channel, host) - else: - self._create_channel = partial( - grpc_helpers_async.create_channel, - target=host, - credentials=credentials, - credentials_file=credentials_file, - quota_project_id=quota_project_id, - default_scopes=default_scopes, - scopes=scopes, - default_host=default_host, - **kwargs, - ) - for i in range(pool_size): - self._pool.append(self._create_channel()) - - def next_channel(self) -> aio.Channel: - channel = self._pool[self._next_idx] - self._next_idx = (self._next_idx + 1) % len(self._pool) - return channel - - def unary_unary(self, *args, **kwargs) -> grpc.aio.UnaryUnaryMultiCallable: - return PooledUnaryUnaryMultiCallable(self, *args, **kwargs) - - def unary_stream(self, *args, **kwargs) -> grpc.aio.UnaryStreamMultiCallable: - return PooledUnaryStreamMultiCallable(self, *args, **kwargs) - - def stream_unary(self, *args, **kwargs) -> grpc.aio.StreamUnaryMultiCallable: - return PooledStreamUnaryMultiCallable(self, *args, **kwargs) - - def stream_stream(self, *args, **kwargs) -> grpc.aio.StreamStreamMultiCallable: - return PooledStreamStreamMultiCallable(self, *args, **kwargs) - - async def close(self, grace=None): - close_fns = [channel.close(grace=grace) for channel in self._pool] - return await asyncio.gather(*close_fns) - - async def channel_ready(self): - ready_fns = [channel.channel_ready() for channel in self._pool] - return asyncio.gather(*ready_fns) - - async def __aenter__(self): - return self - - async def 
__aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: - raise NotImplementedError() - - async def wait_for_state_change(self, last_observed_state): - raise NotImplementedError() - - async def replace_channel( - self, channel_idx, grace=None, swap_sleep=1, new_channel=None - ) -> aio.Channel: - """ - Replaces a channel in the pool with a fresh one. - - The `new_channel` will start processing new requests immidiately, - but the old channel will continue serving existing clients for `grace` seconds - - Args: - channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait until all active RPCs are - finished. If a grace period is not specified (by passing None for - grace), all existing RPCs are cancelled immediately. - swap_sleep(Optional[float]): The number of seconds to sleep in between - replacing channels and closing the old one - new_channel(grpc.aio.Channel): a new channel to insert into the pool - at `channel_idx`. If `None`, a new channel will be created. - """ - if channel_idx >= len(self._pool) or channel_idx < 0: - raise ValueError( - f"invalid channel_idx {channel_idx} for pool size {len(self._pool)}" - ) - if new_channel is None: - new_channel = self._create_channel() - old_channel = self._pool[channel_idx] - self._pool[channel_idx] = new_channel - await asyncio.sleep(swap_sleep) - await old_channel.close(grace=grace) - return new_channel - - -class PooledBigtableGrpcAsyncIOTransport(BigtableGrpcAsyncIOTransport): - """Pooled gRPC AsyncIO backend transport for Bigtable. - - Service for reading from and writing to existing Bigtable - tables. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. 
- - This class allows channel pooling, so multiple channels can be used concurrently - when making requests. Channels are rotated in a round-robin fashion. - """ - - @classmethod - def with_fixed_size(cls, pool_size) -> Type["PooledBigtableGrpcAsyncIOTransport"]: - """ - Creates a new class with a fixed channel pool size. - - A fixed channel pool makes compatibility with other transports easier, - as the initializer signature is the same. - """ - - class PooledTransportFixed(cls): - __init__ = partialmethod(cls.__init__, pool_size=pool_size) - - PooledTransportFixed.__name__ = f"{cls.__name__}_{pool_size}" - PooledTransportFixed.__qualname__ = PooledTransportFixed.__name__ - return PooledTransportFixed - - @classmethod - def create_channel( - cls, - pool_size: int = 3, - host: str = "bigtable.googleapis.com", - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: - """Create and return a PooledChannel object, representing a pool of gRPC AsyncIO channels - Args: - pool_size (int): The number of channels in the pool. - host (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. 
- kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - PooledChannel: a channel pool object - """ - - return PooledChannel( - pool_size, - host, - credentials=credentials, - credentials_file=credentials_file, - quota_project_id=quota_project_id, - default_scopes=cls.AUTH_SCOPES, - scopes=scopes, - default_host=cls.DEFAULT_HOST, - **kwargs, - ) - - def __init__( - self, - *, - pool_size: int = 3, - host: str = "bigtable.googleapis.com", - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: - """Instantiate the transport. - - Args: - pool_size (int): the number of grpc channels to maintain in a pool - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. 
- api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. - client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): - A callback to provide client certificate bytes and private key bytes, - both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. 
- ValueError: if ``pool_size`` <= 0 - """ - if pool_size <= 0: - raise ValueError(f"invalid pool_size: {pool_size}") - self._ssl_channel_credentials = ssl_channel_credentials - self._stubs: Dict[str, Callable] = {} - - if api_mtls_endpoint: - warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) - if client_cert_source: - warnings.warn("client_cert_source is deprecated", DeprecationWarning) - - if api_mtls_endpoint: - host = api_mtls_endpoint - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - self._ssl_channel_credentials = SslCredentials().ssl_credentials - - else: - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - - # The base transport sets the host, credentials and scopes - BigtableTransport.__init__( - self, - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - client_info=client_info, - always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience, - ) - self._quota_project_id = quota_project_id - self._grpc_channel = type(self).create_channel( - pool_size, - self._host, - # use the credentials which are saved - credentials=self._credentials, - # Set ``credentials_file`` to ``None`` here as - # the credentials that we saved earlier should be used. - credentials_file=None, - scopes=self._scopes, - ssl_credentials=self._ssl_channel_credentials, - quota_project_id=self._quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - # Wrap messages. 
This must be done after self._grpc_channel exists - self._prep_wrapped_messages(client_info) - - @property - def pool_size(self) -> int: - """The number of grpc channels in the pool.""" - return len(self._grpc_channel._pool) - - @property - def channels(self) -> List[grpc.Channel]: - """Acccess the internal list of grpc channels.""" - return self._grpc_channel._pool - - async def replace_channel( - self, channel_idx, grace=None, swap_sleep=1, new_channel=None - ) -> aio.Channel: - """ - Replaces a channel in the pool with a fresh one. - - The `new_channel` will start processing new requests immidiately, - but the old channel will continue serving existing clients for `grace` seconds - - Args: - channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait until all active RPCs are - finished. If a grace period is not specified (by passing None for - grace), all existing RPCs are cancelled immediately. - swap_sleep(Optional[float]): The number of seconds to sleep in between - replacing channels and closing the old one - new_channel(grpc.aio.Channel): a new channel to insert into the pool - at `channel_idx`. If `None`, a new channel will be created. 
- """ - return await self._grpc_channel.replace_channel( - channel_idx, grace, swap_sleep, new_channel - ) - - -__all__ = ("PooledBigtableGrpcAsyncIOTransport",) diff --git a/owlbot.py b/owlbot.py index 0ec4cd61c..c5f0acf6c 100644 --- a/owlbot.py +++ b/owlbot.py @@ -97,52 +97,6 @@ def get_staging_dirs( s.move(templated_files, excludes=[".coveragerc", "README.rst", ".github/release-please.yml", "noxfile.py"]) -# ---------------------------------------------------------------------------- -# Customize gapics to include PooledBigtableGrpcAsyncIOTransport -# ---------------------------------------------------------------------------- -def insert(file, before_line, insert_line, after_line, escape=None): - target = before_line + "\n" + after_line - if escape: - for c in escape: - target = target.replace(c, '\\' + c) - replacement = before_line + "\n" + insert_line + "\n" + after_line - s.replace(file, target, replacement) - - -insert( - "google/cloud/bigtable_v2/services/bigtable/client.py", - "from .transports.grpc_asyncio import BigtableGrpcAsyncIOTransport", - "from .transports.pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport", - "from .transports.rest import BigtableRestTransport" -) -insert( - "google/cloud/bigtable_v2/services/bigtable/client.py", - ' _transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport', - ' _transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport', - ' _transport_registry["rest"] = BigtableRestTransport', - escape='[]"' -) -insert( - "google/cloud/bigtable_v2/services/bigtable/transports/__init__.py", - '_transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport', - '_transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport', - '_transport_registry["rest"] = BigtableRestTransport', - escape='[]"' -) -insert( - "google/cloud/bigtable_v2/services/bigtable/transports/__init__.py", - "from .grpc_asyncio import BigtableGrpcAsyncIOTransport", - "from .pooled_grpc_asyncio 
import PooledBigtableGrpcAsyncIOTransport", - "from .rest import BigtableRestTransport" -) -insert( - "google/cloud/bigtable_v2/services/bigtable/transports/__init__.py", - ' "BigtableGrpcAsyncIOTransport",', - ' "PooledBigtableGrpcAsyncIOTransport",', - ' "BigtableRestTransport",', - escape='"' -) - # ---------------------------------------------------------------------------- # Patch duplicate routing header: https://github.com/googleapis/gapic-generator-python/issues/2078 # ---------------------------------------------------------------------------- From f27dea95c411073b514c75bdbc733adf86f495ba Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 1 Nov 2024 11:10:57 -0700 Subject: [PATCH 02/14] removed submodules --- gapic-generator-fork | 1 - python-api-core | 1 - 2 files changed, 2 deletions(-) delete mode 160000 gapic-generator-fork delete mode 160000 python-api-core diff --git a/gapic-generator-fork b/gapic-generator-fork deleted file mode 160000 index b26cda7d1..000000000 --- a/gapic-generator-fork +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b26cda7d163d6e0d45c9684f328ca32fb49b799a diff --git a/python-api-core b/python-api-core deleted file mode 160000 index 17ff5f1d8..000000000 --- a/python-api-core +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 17ff5f1d83a9a6f50a0226fb0e794634bd584f17 From 345593405a88967caaecdeb24d12287b306aa99a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 1 Nov 2024 11:55:53 -0700 Subject: [PATCH 03/14] removed channel pooling from client --- google/cloud/bigtable/data/_async/client.py | 100 ++++++++------------ 1 file changed, 37 insertions(+), 63 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 82a874918..076070f35 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -89,10 +89,7 @@ BigtableAsyncClient, ) from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from 
google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - PooledChannel, -) +from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcAsyncIOTransport from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest if TYPE_CHECKING: @@ -104,11 +101,11 @@ def __init__( self, *, project: str | None = None, - pool_size: int = 3, credentials: google.auth.credentials.Credentials | None = None, client_options: dict[str, Any] | "google.api_core.client_options.ClientOptions" | None = None, + **kwargs ): """ Create a client instance for the Bigtable Data API @@ -119,8 +116,6 @@ def __init__( project: the project which the client acts on behalf of. If not passed, falls back to the default inferred from the environment. - pool_size: The number of grpc channels to maintain - in the internal channel pool. credentials: Thehe OAuth2 Credentials to use for this client. If not passed (and if no ``_http`` object is @@ -131,12 +126,9 @@ def __init__( on the client. API Endpoint should be set through client_options. 
Raises: RuntimeError: if called outside of an async context (no running event loop) - ValueError: if pool_size is less than 1 """ - # set up transport in registry - transport_str = f"pooled_grpc_asyncio_{pool_size}" - transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) - BigtableClientMeta._transport_registry[transport_str] = transport + if "pool_size" in kwargs: + warnings.warn("pool_size no longer supported") # set up client info headers for veneer library client_info = DEFAULT_CLIENT_INFO client_info.client_library_version = self._client_version() @@ -146,9 +138,16 @@ def __init__( client_options = cast( Optional[client_options_lib.ClientOptions], client_options ) + custom_channel = None self._emulator_host = os.getenv(BIGTABLE_EMULATOR) if self._emulator_host is not None: + warnings.warn( + "Connecting to Bigtable emulator at {}".format(self._emulator_host), + RuntimeWarning, + stacklevel=2, + ) # use insecure channel if emulator is set + custom_channel = grpc.aio.insecure_channel(self._emulator_host) if credentials is None: credentials = google.auth.credentials.AnonymousCredentials() if project is None: @@ -161,37 +160,20 @@ def __init__( client_options=client_options, ) self._gapic_client = BigtableAsyncClient( - transport=transport_str, credentials=credentials, client_options=client_options, client_info=client_info, + transport=lambda *args, **kwargs: BigtableGrpcAsyncIOTransport(*args, **kwargs, channel=custom_channel) ) - self.transport = cast( - PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport - ) + self.transport = self._gapic_client.transport # keep track of active instances to for warmup on channel refresh self._active_instances: Set[_WarmedInstanceKey] = set() # keep track of table objects associated with each instance # only remove instance from _active_instances when all associated tables remove it self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() - 
self._channel_refresh_tasks: list[asyncio.Task[None]] = [] - if self._emulator_host is not None: - # connect to an emulator host - warnings.warn( - "Connecting to Bigtable emulator at {}".format(self._emulator_host), - RuntimeWarning, - stacklevel=2, - ) - self.transport._grpc_channel = PooledChannel( - pool_size=pool_size, - host=self._emulator_host, - insecure=True, - ) - # refresh cached stubs to use emulator pool - self.transport._stubs = {} - self.transport._prep_wrapped_messages(client_info) - else: + self._channel_refresh_task: asyncio.Task[None] | None = None + if self._emulator_host is None: # attempt to start background channel refresh tasks try: self._start_background_channel_refresh() @@ -212,33 +194,30 @@ def _client_version() -> str: def _start_background_channel_refresh(self) -> None: """ - Starts a background task to ping and warm each channel in the pool + Starts a background task to ping and warm grpc channel Raises: RuntimeError: if not called in an asyncio event loop """ - if not self._channel_refresh_tasks and not self._emulator_host: + if not self._channel_refresh_task and not self._emulator_host: # raise RuntimeError if there is no event loop asyncio.get_running_loop() - for channel_idx in range(self.transport.pool_size): - refresh_task = asyncio.create_task(self._manage_channel(channel_idx)) - if sys.version_info >= (3, 8): - # task names supported in Python 3.8+ - refresh_task.set_name( - f"{self.__class__.__name__} channel refresh {channel_idx}" - ) - self._channel_refresh_tasks.append(refresh_task) + self._channel_refresh_task = asyncio.create_task(self._manage_channel()) + if sys.version_info >= (3, 8): + # task names supported in Python 3.8+ + self._channel_refresh_task.set_name( + f"{self.__class__.__name__} channel refresh" + ) async def close(self, timeout: float = 2.0): """ Cancel all background tasks """ - for task in self._channel_refresh_tasks: - task.cancel() - group = asyncio.gather(*self._channel_refresh_tasks, 
return_exceptions=True) - await asyncio.wait_for(group, timeout=timeout) + if self._channel_refresh_task: + self._channel_refresh_task.cancel() + await asyncio.wait_for(self._channel_refresh_task, timeout=timeout) await self.transport.close() - self._channel_refresh_tasks = [] + self._channel_refresh_task = None async def _ping_and_warm_instances( self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None @@ -282,7 +261,6 @@ async def _ping_and_warm_instances( async def _manage_channel( self, - channel_idx: int, refresh_interval_min: float = 60 * 35, refresh_interval_max: float = 60 * 45, grace_period: float = 60 * 10, @@ -296,7 +274,6 @@ async def _manage_channel( Runs continuously until the client is closed Args: - channel_idx: index of the channel in the transport's channel pool refresh_interval_min: minimum interval before initiating refresh process in seconds. Actual interval will be a random value between `refresh_interval_min` and `refresh_interval_max` @@ -312,19 +289,18 @@ async def _manage_channel( next_sleep = max(first_refresh - time.monotonic(), 0) if next_sleep > 0: # warm the current channel immediately - channel = self.transport.channels[channel_idx] - await self._ping_and_warm_instances(channel) + await self._ping_and_warm_instances(self.transport._grpc_channel) # continuously refresh the channel every `refresh_interval` seconds while True: await asyncio.sleep(next_sleep) + start_timestamp = time.time() # prepare new channel for use - new_channel = self.transport.grpc_channel._create_channel() + old_channel = self.transport._grpc_channel + new_channel = self.transport_create_channel() await self._ping_and_warm_instances(new_channel) # cycle channel out of use, with long grace window before closure - start_timestamp = time.time() - await self.transport.replace_channel( - channel_idx, grace=grace_period, swap_sleep=10, new_channel=new_channel - ) + self.transport._grpc_channel = new_channel + await old_channel.close(grace_period) 
# subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.time() - start_timestamp) @@ -333,9 +309,8 @@ async def _register_instance( self, instance_id: str, owner: Union[TableAsync, ExecuteQueryIteratorAsync] ) -> None: """ - Registers an instance with the client, and warms the channel pool - for the instance - The client will periodically refresh grpc channel pool used to make + Registers an instance with the client, and warms the channel for the instance + The client will periodically refresh grpc channel used to make requests, and new channels will be warmed for each registered instance Channels will not be refreshed unless at least one instance is registered @@ -352,11 +327,10 @@ async def _register_instance( self._instance_owners.setdefault(instance_key, set()).add(id(owner)) if instance_name not in self._active_instances: self._active_instances.add(instance_key) - if self._channel_refresh_tasks: + if self._channel_refresh_task: # refresh tasks already running # call ping and warm on all existing channels - for channel in self.transport.channels: - await self._ping_and_warm_instances(channel, instance_key) + await self._ping_and_warm_instances(instance_key) else: # refresh tasks aren't active. 
start them as background tasks self._start_background_channel_refresh() From 7bd5259e79ad06d43247a23ecd1d99894cf7be92 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 1 Nov 2024 12:03:00 -0700 Subject: [PATCH 04/14] fixed some tests --- tests/unit/data/_async/test_client.py | 114 ++++---------------------- 1 file changed, 15 insertions(+), 99 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 6c49ca0da..32338a817 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -51,7 +51,7 @@ def _make_client(*args, use_emulator=True, **kwargs): env_mask = {} # by default, use emulator mode to avoid auth issues in CI - # emulator mode must be disabled by tests that check channel pooling/refresh background tasks + # emulator mode must be disabled by tests that check refresh background tasks if use_emulator: env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" else: @@ -74,19 +74,16 @@ def _make_one(self, *args, **kwargs): @pytest.mark.asyncio async def test_ctor(self): expected_project = "project-id" - expected_pool_size = 11 expected_credentials = AnonymousCredentials() client = self._make_one( project="project-id", - pool_size=expected_pool_size, credentials=expected_credentials, use_emulator=False, ) await asyncio.sleep(0) assert client.project == expected_project - assert len(client.transport._grpc_channel._pool) == expected_pool_size assert not client._active_instances - assert len(client._channel_refresh_tasks) == expected_pool_size + assert client._channel_refresh_tasks is not None assert client.transport._credentials == expected_credentials await client.close() @@ -99,11 +96,9 @@ async def test_ctor_super_inits(self): from google.api_core import client_options as client_options_lib project = "project-id" - pool_size = 11 credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = 
client_options_lib.from_dict(client_options) - transport_str = f"pooled_grpc_asyncio_{pool_size}" with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: bigtable_client_init.return_value = None with mock.patch.object( @@ -113,7 +108,6 @@ async def test_ctor_super_inits(self): try: self._make_one( project=project, - pool_size=pool_size, credentials=credentials, client_options=options_parsed, use_emulator=False, @@ -179,78 +173,6 @@ async def test_veneer_grpc_headers(self): ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" await client.close() - @pytest.mark.asyncio - async def test_channel_pool_creation(self): - pool_size = 14 - with mock.patch( - "google.api_core.grpc_helpers_async.create_channel" - ) as create_channel: - create_channel.return_value = AsyncMock() - client = self._make_one(project="project-id", pool_size=pool_size) - assert create_channel.call_count == pool_size - await client.close() - # channels should be unique - client = self._make_one(project="project-id", pool_size=pool_size) - pool_list = list(client.transport._grpc_channel._pool) - pool_set = set(client.transport._grpc_channel._pool) - assert len(pool_list) == len(pool_set) - await client.close() - - @pytest.mark.asyncio - async def test_channel_pool_rotation(self): - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel, - ) - - pool_size = 7 - - with mock.patch.object(PooledChannel, "next_channel") as next_channel: - client = self._make_one(project="project-id", pool_size=pool_size) - assert len(client.transport._grpc_channel._pool) == pool_size - next_channel.reset_mock() - with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "unary_unary" - ) as unary_unary: - # calling an rpc `pool_size` times should use a different channel each time - channel_next = None - for i in range(pool_size): - channel_last = channel_next - channel_next = client.transport.grpc_channel._pool[i] - 
assert channel_last != channel_next - next_channel.return_value = channel_next - client.transport.ping_and_warm() - assert next_channel.call_count == i + 1 - unary_unary.assert_called_once() - unary_unary.reset_mock() - await client.close() - - @pytest.mark.asyncio - async def test_channel_pool_replace(self): - with mock.patch.object(asyncio, "sleep"): - pool_size = 7 - client = self._make_one(project="project-id", pool_size=pool_size) - for replace_idx in range(pool_size): - start_pool = [ - channel for channel in client.transport._grpc_channel._pool - ] - grace_period = 9 - with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "close" - ) as close: - new_channel = grpc.aio.insecure_channel("localhost:8080") - await client.transport.replace_channel( - replace_idx, grace=grace_period, new_channel=new_channel - ) - close.assert_called_once_with(grace=grace_period) - close.assert_awaited_once() - assert client.transport._grpc_channel._pool[replace_idx] == new_channel - for i in range(pool_size): - if i != replace_idx: - assert client.transport._grpc_channel._pool[i] == start_pool[i] - else: - assert client.transport._grpc_channel._pool[i] != start_pool[i] - await client.close() - @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test__start_background_channel_refresh_sync(self): # should raise RuntimeError if called in a sync context @@ -259,48 +181,42 @@ def test__start_background_channel_refresh_sync(self): client._start_background_channel_refresh() @pytest.mark.asyncio - async def test__start_background_channel_refresh_tasks_exist(self): + async def test__start_background_channel_refresh_task_exists(self): # if tasks exist, should do nothing client = self._make_one(project="project-id", use_emulator=False) - assert len(client._channel_refresh_tasks) > 0 + assert client._channel_refresh_task is not None with mock.patch.object(asyncio, "create_task") as create_task: client._start_background_channel_refresh() create_task.assert_not_called() 
await client.close() @pytest.mark.asyncio - @pytest.mark.parametrize("pool_size", [1, 3, 7]) - async def test__start_background_channel_refresh(self, pool_size): + async def test__start_background_channel_refresh(self): # should create background tasks for each channel client = self._make_one( - project="project-id", pool_size=pool_size, use_emulator=False + project="project-id", use_emulator=False ) ping_and_warm = AsyncMock() client._ping_and_warm_instances = ping_and_warm client._start_background_channel_refresh() - assert len(client._channel_refresh_tasks) == pool_size - for task in client._channel_refresh_tasks: - assert isinstance(task, asyncio.Task) + assert client._channel_refresh_task is not None + assert isinstance(client._channel_refresh_task, asyncio.Task) await asyncio.sleep(0.1) - assert ping_and_warm.call_count == pool_size - for channel in client.transport._grpc_channel._pool: - ping_and_warm.assert_any_call(channel) + assert ping_and_warm.call_count == 1 + ping_and_warm.assert_any_call(client.transport._grpc_channel) await client.close() @pytest.mark.asyncio @pytest.mark.skipif( sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" ) - async def test__start_background_channel_refresh_tasks_names(self): + async def test__start_background_channel_refresh_task_names(self): # if tasks exist, should do nothing - pool_size = 3 client = self._make_one( - project="project-id", pool_size=pool_size, use_emulator=False + project="project-id", use_emulator=False ) - for i in range(pool_size): - name = client._channel_refresh_tasks[i].get_name() - assert str(i) in name - assert "BigtableDataClientAsync channel refresh " in name + name = client._channel_refresh_task.get_name() + assert "BigtableDataClientAsync channel refresh" in name await client.close() @pytest.mark.asyncio @@ -594,7 +510,7 @@ async def test__register_instance(self): instance_owners = {} client_mock._active_instances = active_instances client_mock._instance_owners = 
instance_owners - client_mock._channel_refresh_tasks = [] + client_mock._channel_refresh_task is None client_mock._start_background_channel_refresh.side_effect = ( lambda: client_mock._channel_refresh_tasks.append(mock.Mock) ) From 852128787eeea05f488677422373e13e700b10f1 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 1 Nov 2024 12:47:09 -0700 Subject: [PATCH 05/14] removed channel from _ping_and_warm_instances --- google/cloud/bigtable/data/_async/client.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 076070f35..c60ff4144 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -220,7 +220,7 @@ async def close(self, timeout: float = 2.0): self._channel_refresh_task = None async def _ping_and_warm_instances( - self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None + self, instance_key: _WarmedInstanceKey | None = None ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel @@ -228,7 +228,6 @@ async def _ping_and_warm_instances( Pings each Bigtable instance registered in `_active_instances` on the client Args: - channel: grpc channel to warm instance_key: if provided, only warm the instance associated with the key Returns: list[BaseException | None]: sequence of results or exceptions from the ping requests @@ -236,7 +235,7 @@ async def _ping_and_warm_instances( instance_list = ( [instance_key] if instance_key is not None else self._active_instances ) - ping_rpc = channel.unary_unary( + ping_rpc = self.transport._grpc_channel.unary_unary( "/google.bigtable.v2.Bigtable/PingAndWarm", request_serializer=PingAndWarmRequest.serialize, ) @@ -289,7 +288,7 @@ async def _manage_channel( next_sleep = max(first_refresh - time.monotonic(), 0) if next_sleep > 0: # warm the current channel immediately - await 
self._ping_and_warm_instances(self.transport._grpc_channel) + await self._ping_and_warm_instances() # continuously refresh the channel every `refresh_interval` seconds while True: await asyncio.sleep(next_sleep) @@ -297,7 +296,7 @@ async def _manage_channel( # prepare new channel for use old_channel = self.transport._grpc_channel new_channel = self.transport_create_channel() - await self._ping_and_warm_instances(new_channel) + await self._ping_and_warm_instances() # cycle channel out of use, with long grace window before closure self.transport._grpc_channel = new_channel await old_channel.close(grace_period) From 6eedc9c549b4ea8925941bf652e7711924a292bb Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 1 Nov 2024 13:24:04 -0700 Subject: [PATCH 06/14] fixed lint issues --- google/cloud/bigtable/data/_async/client.py | 21 +++++++++++++-------- tests/unit/data/_async/test_client.py | 11 +++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index c60ff4144..b348057c5 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -88,8 +88,9 @@ DEFAULT_CLIENT_INFO, BigtableAsyncClient, ) -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcAsyncIOTransport +from google.cloud.bigtable_v2.services.bigtable.transports import ( + BigtableGrpcAsyncIOTransport, +) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest if TYPE_CHECKING: @@ -105,7 +106,7 @@ def __init__( client_options: dict[str, Any] | "google.api_core.client_options.ClientOptions" | None = None, - **kwargs + **kwargs, ): """ Create a client instance for the Bigtable Data API @@ -163,9 +164,13 @@ def __init__( credentials=credentials, client_options=client_options, client_info=client_info, - transport=lambda *args, **kwargs: 
BigtableGrpcAsyncIOTransport(*args, **kwargs, channel=custom_channel) + transport=lambda *args, **kwargs: BigtableGrpcAsyncIOTransport( + *args, **kwargs, channel=custom_channel + ), + ) + self.transport = cast( + BigtableGrpcAsyncIOTransport, self._gapic_client.transport ) - self.transport = self._gapic_client.transport # keep track of active instances to for warmup on channel refresh self._active_instances: Set[_WarmedInstanceKey] = set() # keep track of table objects associated with each instance @@ -235,7 +240,7 @@ async def _ping_and_warm_instances( instance_list = ( [instance_key] if instance_key is not None else self._active_instances ) - ping_rpc = self.transport._grpc_channel.unary_unary( + ping_rpc = self.transport.grpc_channel.unary_unary( "/google.bigtable.v2.Bigtable/PingAndWarm", request_serializer=PingAndWarmRequest.serialize, ) @@ -294,8 +299,8 @@ async def _manage_channel( await asyncio.sleep(next_sleep) start_timestamp = time.time() # prepare new channel for use - old_channel = self.transport._grpc_channel - new_channel = self.transport_create_channel() + old_channel = self.transport.grpc_channel + new_channel = self.transport.create_channel() await self._ping_and_warm_instances() # cycle channel out of use, with long grace window before closure self.transport._grpc_channel = new_channel diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 32338a817..c6ad1995f 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -117,7 +117,6 @@ async def test_ctor_super_inits(self): # test gapic superclass init was called assert bigtable_client_init.call_count == 1 kwargs = bigtable_client_init.call_args[1] - assert kwargs["transport"] == transport_str assert kwargs["credentials"] == credentials assert kwargs["client_options"] == options_parsed # test mixin superclass init was called @@ -193,9 +192,7 @@ async def test__start_background_channel_refresh_task_exists(self): 
@pytest.mark.asyncio async def test__start_background_channel_refresh(self): # should create background tasks for each channel - client = self._make_one( - project="project-id", use_emulator=False - ) + client = self._make_one(project="project-id", use_emulator=False) ping_and_warm = AsyncMock() client._ping_and_warm_instances = ping_and_warm client._start_background_channel_refresh() @@ -203,7 +200,7 @@ async def test__start_background_channel_refresh(self): assert isinstance(client._channel_refresh_task, asyncio.Task) await asyncio.sleep(0.1) assert ping_and_warm.call_count == 1 - ping_and_warm.assert_any_call(client.transport._grpc_channel) + ping_and_warm.assert_any_call(client.transport.grpc_channel) await client.close() @pytest.mark.asyncio @@ -212,9 +209,7 @@ async def test__start_background_channel_refresh(self): ) async def test__start_background_channel_refresh_task_names(self): # if tasks exist, should do nothing - client = self._make_one( - project="project-id", use_emulator=False - ) + client = self._make_one(project="project-id", use_emulator=False) name = client._channel_refresh_task.get_name() assert "BigtableDataClientAsync channel refresh" in name await client.close() From 72c181c7d6f704af7159cc2b2cf9acf3eba08706 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 5 Nov 2024 16:00:12 -0800 Subject: [PATCH 07/14] catch cancellation exception --- google/cloud/bigtable/data/_async/client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index b348057c5..cd95f67d6 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -220,7 +220,10 @@ async def close(self, timeout: float = 2.0): """ if self._channel_refresh_task: self._channel_refresh_task.cancel() - await asyncio.wait_for(self._channel_refresh_task, timeout=timeout) + try: + await asyncio.wait_for(self._channel_refresh_task, 
timeout=timeout) + except asyncio.CancelledError: + pass await self.transport.close() self._channel_refresh_task = None From d9acce1a3ff6736a10171ae7bc1746d6b513ec1e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 5 Nov 2024 16:00:41 -0800 Subject: [PATCH 08/14] re-enable warming external channels --- google/cloud/bigtable/data/_async/client.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index cd95f67d6..0f747725d 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -228,7 +228,7 @@ async def close(self, timeout: float = 2.0): self._channel_refresh_task = None async def _ping_and_warm_instances( - self, instance_key: _WarmedInstanceKey | None = None + self, instance_key: _WarmedInstanceKey | None = None, channel: grpc.aio.Channel|None=None ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel @@ -237,13 +237,15 @@ async def _ping_and_warm_instances( Args: instance_key: if provided, only warm the instance associated with the key + channel: grpc channel to warm. 
If none, warms `self.transport.grpc_channel` Returns: list[BaseException | None]: sequence of results or exceptions from the ping requests """ + channel = channel or self.transport.grpc_channel instance_list = ( [instance_key] if instance_key is not None else self._active_instances ) - ping_rpc = self.transport.grpc_channel.unary_unary( + ping_rpc = channel.unary_unary( "/google.bigtable.v2.Bigtable/PingAndWarm", request_serializer=PingAndWarmRequest.serialize, ) @@ -296,7 +298,7 @@ async def _manage_channel( next_sleep = max(first_refresh - time.monotonic(), 0) if next_sleep > 0: # warm the current channel immediately - await self._ping_and_warm_instances() + await self._ping_and_warm_instances(channel=self.transport.grpc_channel) # continuously refresh the channel every `refresh_interval` seconds while True: await asyncio.sleep(next_sleep) @@ -304,7 +306,7 @@ async def _manage_channel( # prepare new channel for use old_channel = self.transport.grpc_channel new_channel = self.transport.create_channel() - await self._ping_and_warm_instances() + await self._ping_and_warm_instances(channel=new_channel) # cycle channel out of use, with long grace window before closure self.transport._grpc_channel = new_channel await old_channel.close(grace_period) From 249d688cd2cc20199508962b914bd75a802f320d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 5 Nov 2024 16:01:35 -0800 Subject: [PATCH 09/14] fixed tests --- tests/system/data/test_execute_query_utils.py | 3 +- tests/system/data/test_system.py | 7 +- tests/unit/data/_async/test_client.py | 181 ++++++------------ 3 files changed, 59 insertions(+), 132 deletions(-) diff --git a/tests/system/data/test_execute_query_utils.py b/tests/system/data/test_execute_query_utils.py index 9e27b95f2..bca660fe8 100644 --- a/tests/system/data/test_execute_query_utils.py +++ b/tests/system/data/test_execute_query_utils.py @@ -14,7 +14,6 @@ from unittest import mock -import 
google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio as pga from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse from google.cloud.bigtable_v2.types.data import ProtoRows, Value as PBValue import grpc.aio @@ -143,7 +142,7 @@ def unary_stream(self, *args, **kwargs): return mock.MagicMock() -class ChannelMockAsync(pga.PooledChannel, mock.MagicMock): +class ChannelMockAsync(grpc.aio.Channel, mock.MagicMock): def __init__(self, *args, **kwargs): mock.MagicMock.__init__(self, *args, **kwargs) self.execute_query_calls = [] diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index 9fe208551..8f31827ed 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -168,12 +168,7 @@ async def test_ping_and_warm(client, table): """ Test ping and warm from handwritten client """ - try: - channel = client.transport._grpc_channel.pool[0] - except Exception: - # for sync client - channel = client.transport._grpc_channel - results = await client._ping_and_warm_instances(channel) + results = await client._ping_and_warm_instances() assert len(results) == 1 assert results[0] is None diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index c6ad1995f..14dffe66c 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -83,7 +83,7 @@ async def test_ctor(self): await asyncio.sleep(0) assert client.project == expected_project assert not client._active_instances - assert client._channel_refresh_tasks is not None + assert client._channel_refresh_task is not None assert client.transport._credentials == expected_credentials await client.close() @@ -200,7 +200,6 @@ async def test__start_background_channel_refresh(self): assert isinstance(client._channel_refresh_task, asyncio.Task) await asyncio.sleep(0.1) assert ping_and_warm.call_count == 1 - ping_and_warm.assert_any_call(client.transport.grpc_channel) await 
client.close() @pytest.mark.asyncio @@ -227,7 +226,7 @@ async def test__ping_and_warm_instances(self): # test with no instances client_mock._active_instances = [] result = await self._get_target_class()._ping_and_warm_instances( - client_mock, channel + client_mock, channel=channel ) assert len(result) == 0 gather.assert_called_once() @@ -241,7 +240,7 @@ async def test__ping_and_warm_instances(self): gather.reset_mock() channel.reset_mock() result = await self._get_target_class()._ping_and_warm_instances( - client_mock, channel + client_mock, channel=channel ) assert len(result) == 4 gather.assert_called_once() @@ -275,17 +274,16 @@ async def test__ping_and_warm_single_instance(self): with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: # simulate gather by returning the same number of items as passed in gather.side_effect = lambda *args, **kwargs: [None for _ in args] - channel = mock.Mock() # test with large set of instances client_mock._active_instances = [mock.Mock()] * 100 test_key = ("test-instance", "test-table", "test-app-profile") result = await self._get_target_class()._ping_and_warm_instances( - client_mock, channel, test_key + client_mock, test_key ) # should only have been called with test instance assert len(result) == 1 # check grpc call arguments - grpc_call_args = channel.unary_unary().call_args_list + grpc_call_args = client_mock.transport.grpc_channel.unary_unary().call_args_list assert len(grpc_call_args) == 1 kwargs = grpc_call_args[0][1] request = kwargs["request"] @@ -323,7 +321,7 @@ async def test__manage_channel_first_sleep( try: client = self._make_one(project="project-id") client._channel_init_time = -wait_time - await client._manage_channel(0, refresh_interval, refresh_interval) + await client._manage_channel(refresh_interval, refresh_interval) except asyncio.CancelledError: pass sleep.assert_called_once() @@ -342,40 +340,27 @@ async def test__manage_channel_ping_and_warm(self): client_mock = mock.Mock() 
client_mock._channel_init_time = time.monotonic() - channel_list = [mock.Mock(), mock.Mock()] - client_mock.transport.channels = channel_list - new_channel = mock.Mock() - client_mock.transport.grpc_channel._create_channel.return_value = new_channel + orig_channel = client_mock.transport.grpc_channel # should ping an warm all new channels, and old channels if sleeping with mock.patch.object(asyncio, "sleep"): - # stop process after replace_channel is called - client_mock.transport.replace_channel.side_effect = asyncio.CancelledError + # stop process after close is called + orig_channel.close.side_effect = asyncio.CancelledError ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() # should ping and warm old channel then new if sleep > 0 try: - channel_idx = 1 await self._get_target_class()._manage_channel( - client_mock, channel_idx, 10 + client_mock, 10 ) except asyncio.CancelledError: pass # should have called at loop start, and after replacement assert ping_and_warm.call_count == 2 # should have replaced channel once - assert client_mock.transport.replace_channel.call_count == 1 + assert client_mock.transport._grpc_channel != orig_channel # make sure new and old channels were warmed - old_channel = channel_list[channel_idx] - assert old_channel != new_channel - called_with = [call[0][0] for call in ping_and_warm.call_args_list] - assert old_channel in called_with - assert new_channel in called_with - # should ping and warm instantly new channel only if not sleeping - ping_and_warm.reset_mock() - try: - await self._get_target_class()._manage_channel(client_mock, 0, 0, 0) - except asyncio.CancelledError: - pass - ping_and_warm.assert_called_once_with(new_channel) + called_with = [call[1]["channel"] for call in ping_and_warm.call_args_list] + assert orig_channel in called_with + assert client_mock.transport.grpc_channel in called_with @pytest.mark.asyncio @pytest.mark.parametrize( @@ -393,7 +378,6 @@ async def test__manage_channel_sleeps( import time 
import random - channel_idx = 1 with mock.patch.object(random, "uniform") as uniform: uniform.side_effect = lambda min_, max_: min_ with mock.patch.object(time, "time") as time: @@ -406,10 +390,10 @@ async def test__manage_channel_sleeps( client = self._make_one(project="project-id") if refresh_interval is not None: await client._manage_channel( - channel_idx, refresh_interval, refresh_interval + refresh_interval, refresh_interval ) else: - await client._manage_channel(channel_idx) + await client._manage_channel() except asyncio.CancelledError: pass assert sleep.call_count == num_cycles @@ -428,7 +412,7 @@ async def test__manage_channel_random(self): uniform.return_value = 0 try: uniform.side_effect = asyncio.CancelledError - client = self._make_one(project="project-id", pool_size=1) + client = self._make_one(project="project-id") except asyncio.CancelledError: uniform.side_effect = None uniform.reset_mock() @@ -436,9 +420,9 @@ async def test__manage_channel_random(self): min_val = 200 max_val = 205 uniform.side_effect = lambda min_, max_: min_ - sleep.side_effect = [None, None, asyncio.CancelledError] + sleep.side_effect = [None, asyncio.CancelledError] try: - await client._manage_channel(0, min_val, max_val) + await client._manage_channel(min_val, max_val) except asyncio.CancelledError: pass assert uniform.call_count == 2 @@ -451,47 +435,33 @@ async def test__manage_channel_random(self): @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) async def test__manage_channel_refresh(self, num_cycles): # make sure that channels are properly refreshed - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) from google.api_core import grpc_helpers_async expected_grace = 9 expected_refresh = 0.5 - channel_idx = 1 new_channel = grpc.aio.insecure_channel("localhost:8080") - with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "replace_channel" - ) as replace_channel: - with 
mock.patch.object(asyncio, "sleep") as sleep: - sleep.side_effect = [None for i in range(num_cycles)] + [ - asyncio.CancelledError - ] - with mock.patch.object( - grpc_helpers_async, "create_channel" - ) as create_channel: - create_channel.return_value = new_channel - client = self._make_one(project="project-id", use_emulator=False) - create_channel.reset_mock() - try: - await client._manage_channel( - channel_idx, - refresh_interval_min=expected_refresh, - refresh_interval_max=expected_refresh, - grace_period=expected_grace, - ) - except asyncio.CancelledError: - pass - assert sleep.call_count == num_cycles + 1 - assert create_channel.call_count == num_cycles - assert replace_channel.call_count == num_cycles - for call in replace_channel.call_args_list: - args, kwargs = call - assert args[0] == channel_idx - assert kwargs["grace"] == expected_grace - assert kwargs["new_channel"] == new_channel - await client.close() + with mock.patch.object(asyncio, "sleep") as sleep: + sleep.side_effect = [None for i in range(num_cycles)] + [ + asyncio.CancelledError + ] + with mock.patch.object( + grpc_helpers_async, "create_channel" + ) as create_channel: + create_channel.return_value = new_channel + client = self._make_one(project="project-id", use_emulator=False) + create_channel.reset_mock() + try: + await client._manage_channel( + refresh_interval_min=expected_refresh, + refresh_interval_max=expected_refresh, + grace_period=expected_grace, + ) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + 1 + assert create_channel.call_count == num_cycles + await client.close() @pytest.mark.asyncio async def test__register_instance(self): @@ -505,12 +475,7 @@ async def test__register_instance(self): instance_owners = {} client_mock._active_instances = active_instances client_mock._instance_owners = instance_owners - client_mock._channel_refresh_task is None - client_mock._start_background_channel_refresh.side_effect = ( - lambda: 
client_mock._channel_refresh_tasks.append(mock.Mock) - ) - mock_channels = [mock.Mock() for i in range(5)] - client_mock.transport.channels = mock_channels + client_mock._channel_refresh_task = None client_mock._ping_and_warm_instances = AsyncMock() table_mock = mock.Mock() await self._get_target_class()._register_instance( @@ -528,21 +493,17 @@ async def test__register_instance(self): assert expected_key == tuple(list(active_instances)[0]) assert len(instance_owners) == 1 assert expected_key == tuple(list(instance_owners)[0]) - # should be a new task set - assert client_mock._channel_refresh_tasks + # simulate creation of refresh task + client_mock._channel_refresh_task = mock.Mock() # next call should not call _start_background_channel_refresh again table_mock2 = mock.Mock() await self._get_target_class()._register_instance( client_mock, "instance-2", table_mock2 ) assert client_mock._start_background_channel_refresh.call_count == 1 + assert client_mock._ping_and_warm_instances.call_args[0][0][0] == "prefix/instance-2" # but it should call ping and warm with new instance key - assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) - for channel in mock_channels: - assert channel in [ - call[0][0] - for call in client_mock._ping_and_warm_instances.call_args_list - ] + assert client_mock._ping_and_warm_instances.call_count == 1 # check for updated lists assert len(active_instances) == 2 assert len(instance_owners) == 2 @@ -849,60 +810,33 @@ async def test_get_table_context_manager(self): assert client._instance_owners[instance_key] == {id(table)} assert close_mock.call_count == 1 - @pytest.mark.asyncio - async def test_multiple_pool_sizes(self): - # should be able to create multiple clients with different pool sizes without issue - pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] - for pool_size in pool_sizes: - client = self._make_one( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert 
len(client._channel_refresh_tasks) == pool_size - client_duplicate = self._make_one( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client_duplicate._channel_refresh_tasks) == pool_size - assert str(pool_size) in str(client.transport) - await client.close() - await client_duplicate.close() - @pytest.mark.asyncio async def test_close(self): - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) - - pool_size = 7 client = self._make_one( - project="project-id", pool_size=pool_size, use_emulator=False + project="project-id", use_emulator=False ) - assert len(client._channel_refresh_tasks) == pool_size - tasks_list = list(client._channel_refresh_tasks) - for task in client._channel_refresh_tasks: - assert not task.done() + task = client._channel_refresh_task + assert task is not None + assert not task.done() with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "close", AsyncMock() + client.transport, "close", AsyncMock() ) as close_mock: await client.close() close_mock.assert_called_once() close_mock.assert_awaited() - for task in tasks_list: - assert task.done() - assert task.cancelled() - assert client._channel_refresh_tasks == [] + assert task.done() + assert task.cancelled() + assert client._channel_refresh_task is None @pytest.mark.asyncio async def test_close_with_timeout(self): - pool_size = 7 expected_timeout = 19 - client = self._make_one(project="project-id", pool_size=pool_size) - tasks = list(client._channel_refresh_tasks) + client = self._make_one(project="project-id", use_emulator=False) with mock.patch.object(asyncio, "wait_for", AsyncMock()) as wait_for_mock: await client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() wait_for_mock.assert_awaited() assert wait_for_mock.call_args[1]["timeout"] == expected_timeout - client._channel_refresh_tasks = tasks await client.close() @pytest.mark.asyncio @@ -910,11 +844,10 @@ 
async def test_context_manager(self): # context manager should close the client cleanly close_mock = AsyncMock() true_close = None - async with self._make_one(project="project-id") as client: + async with self._make_one(project="project-id", use_emulator=False) as client: true_close = client.close() client.close = close_mock - for task in client._channel_refresh_tasks: - assert not task.done() + assert not client._channel_refresh_task.done() assert client.project == "project-id" assert client._active_instances == set() close_mock.assert_not_called() @@ -935,7 +868,7 @@ def test_client_ctor_sync(self): in str(expected_warning[0].message) ) assert client.project == "project-id" - assert client._channel_refresh_tasks == [] + assert client._channel_refresh_task is None class TestTableAsync: From 66d7d1cc9595a325c3c27c9c977d4fb46eb86887 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 5 Nov 2024 16:03:56 -0800 Subject: [PATCH 10/14] fixed lint --- google/cloud/bigtable/data/_async/client.py | 4 +++- tests/unit/data/_async/test_client.py | 21 ++++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 0f747725d..4765a3cf2 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -228,7 +228,9 @@ async def close(self, timeout: float = 2.0): self._channel_refresh_task = None async def _ping_and_warm_instances( - self, instance_key: _WarmedInstanceKey | None = None, channel: grpc.aio.Channel|None=None + self, + instance_key: _WarmedInstanceKey | None = None, + channel: grpc.aio.Channel | None = None, ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 14dffe66c..be34d344a 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ 
-283,7 +283,9 @@ async def test__ping_and_warm_single_instance(self): # should only have been called with test instance assert len(result) == 1 # check grpc call arguments - grpc_call_args = client_mock.transport.grpc_channel.unary_unary().call_args_list + grpc_call_args = ( + client_mock.transport.grpc_channel.unary_unary().call_args_list + ) assert len(grpc_call_args) == 1 kwargs = grpc_call_args[0][1] request = kwargs["request"] @@ -348,9 +350,7 @@ async def test__manage_channel_ping_and_warm(self): ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() # should ping and warm old channel then new if sleep > 0 try: - await self._get_target_class()._manage_channel( - client_mock, 10 - ) + await self._get_target_class()._manage_channel(client_mock, 10) except asyncio.CancelledError: pass # should have called at loop start, and after replacement @@ -501,7 +501,10 @@ async def test__register_instance(self): client_mock, "instance-2", table_mock2 ) assert client_mock._start_background_channel_refresh.call_count == 1 - assert client_mock._ping_and_warm_instances.call_args[0][0][0] == "prefix/instance-2" + assert ( + client_mock._ping_and_warm_instances.call_args[0][0][0] + == "prefix/instance-2" + ) # but it should call ping and warm with new instance key assert client_mock._ping_and_warm_instances.call_count == 1 # check for updated lists @@ -812,15 +815,11 @@ async def test_get_table_context_manager(self): @pytest.mark.asyncio async def test_close(self): - client = self._make_one( - project="project-id", use_emulator=False - ) + client = self._make_one(project="project-id", use_emulator=False) task = client._channel_refresh_task assert task is not None assert not task.done() - with mock.patch.object( - client.transport, "close", AsyncMock() - ) as close_mock: + with mock.patch.object(client.transport, "close", AsyncMock()) as close_mock: await client.close() close_mock.assert_called_once() close_mock.assert_awaited() From 
fcb12987306d274fd135da553d1299c08e76cbaa Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 5 Nov 2024 16:55:36 -0800 Subject: [PATCH 11/14] fixed execute_query tests --- tests/system/data/test_execute_query_async.py | 6 +---- tests/system/data/test_execute_query_utils.py | 24 +++++++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/tests/system/data/test_execute_query_async.py b/tests/system/data/test_execute_query_async.py index a680d2de0..99dd60b29 100644 --- a/tests/system/data/test_execute_query_async.py +++ b/tests/system/data/test_execute_query_async.py @@ -39,11 +39,7 @@ def async_channel_mock(self): def async_client(self, async_channel_mock): with mock.patch.dict( os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"} - ), mock.patch.object( - google.cloud.bigtable.data._async.client, - "PooledChannel", - return_value=async_channel_mock, - ): + ), mock.patch("grpc.aio.insecure_channel", return_value=async_channel_mock): yield BigtableDataClientAsync() @pytest.mark.asyncio diff --git a/tests/system/data/test_execute_query_utils.py b/tests/system/data/test_execute_query_utils.py index bca660fe8..3439e04d2 100644 --- a/tests/system/data/test_execute_query_utils.py +++ b/tests/system/data/test_execute_query_utils.py @@ -269,3 +269,27 @@ def wait_for_connection(*args, **kwargs): # PTAL https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc.aio.Channel.unary_stream return UnaryStreamMultiCallableMock(self) return async_mock() + + def stream_unary(self, *args, **kwargs) -> grpc.aio.StreamUnaryMultiCallable: + raise NotImplementedError() + + def stream_stream(self, *args, **kwargs) -> grpc.aio.StreamStreamMultiCallable: + raise NotImplementedError() + + async def close(self, grace=None): + return + + async def channel_ready(self): + return + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + def get_state(self, try_to_connect: bool = False) -> 
grpc.ChannelConnectivity: + raise NotImplementedError() + + async def wait_for_state_change(self, last_observed_state): + raise NotImplementedError() From 89c660c4d60ba236ca56fc758819f5072563cccb Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 5 Nov 2024 16:57:55 -0800 Subject: [PATCH 12/14] fixed lint --- tests/system/data/test_execute_query_async.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/system/data/test_execute_query_async.py b/tests/system/data/test_execute_query_async.py index 99dd60b29..489dfeab6 100644 --- a/tests/system/data/test_execute_query_async.py +++ b/tests/system/data/test_execute_query_async.py @@ -23,7 +23,6 @@ ) from google.api_core import exceptions as core_exceptions from google.cloud.bigtable.data import BigtableDataClientAsync -import google.cloud.bigtable.data._async.client TABLE_NAME = "TABLE_NAME" INSTANCE_NAME = "INSTANCE_NAME" From 0ebee471d578bea6d5ff62d6607ac8cb0ca8a036 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 5 Nov 2024 17:07:33 -0800 Subject: [PATCH 13/14] mock out channel creation --- tests/unit/data/_async/test_client.py | 40 ++++++++++++++------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index be34d344a..bfe27f00d 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -388,12 +388,13 @@ async def test__manage_channel_sleeps( ] try: client = self._make_one(project="project-id") - if refresh_interval is not None: - await client._manage_channel( - refresh_interval, refresh_interval - ) - else: - await client._manage_channel() + with mock.patch.object(client.transport, "create_channel"): + if refresh_interval is not None: + await client._manage_channel( + refresh_interval, refresh_interval + ) + else: + await client._manage_channel() except asyncio.CancelledError: pass assert sleep.call_count == num_cycles @@ -417,19 +418,20 @@ async def 
test__manage_channel_random(self): uniform.side_effect = None uniform.reset_mock() sleep.reset_mock() - min_val = 200 - max_val = 205 - uniform.side_effect = lambda min_, max_: min_ - sleep.side_effect = [None, asyncio.CancelledError] - try: - await client._manage_channel(min_val, max_val) - except asyncio.CancelledError: - pass - assert uniform.call_count == 2 - uniform_args = [call[0] for call in uniform.call_args_list] - for found_min, found_max in uniform_args: - assert found_min == min_val - assert found_max == max_val + with mock.patch.object(client.transport, "create_channel"): + min_val = 200 + max_val = 205 + uniform.side_effect = lambda min_, max_: min_ + sleep.side_effect = [None, asyncio.CancelledError] + try: + await client._manage_channel(min_val, max_val) + except asyncio.CancelledError: + pass + assert uniform.call_count == 2 + uniform_args = [call[0] for call in uniform.call_args_list] + for found_min, found_max in uniform_args: + assert found_min == min_val + assert found_max == max_val @pytest.mark.asyncio @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) From 81322c806296e0f977530fb46dc78865bd33c3a6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 5 Nov 2024 17:16:18 -0800 Subject: [PATCH 14/14] improved channel mocking --- tests/unit/data/_async/test_client.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index bfe27f00d..07c56d230 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -378,6 +378,8 @@ async def test__manage_channel_sleeps( import time import random + channel = mock.Mock() + channel.close = mock.AsyncMock() with mock.patch.object(random, "uniform") as uniform: uniform.side_effect = lambda min_, max_: min_ with mock.patch.object(time, "time") as time: @@ -388,7 +390,10 @@ async def test__manage_channel_sleeps( ] try: client = self._make_one(project="project-id") - with 
mock.patch.object(client.transport, "create_channel"): + client.transport._grpc_channel = channel + with mock.patch.object( + client.transport, "create_channel", return_value=channel + ): if refresh_interval is not None: await client._manage_channel( refresh_interval, refresh_interval