diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index c60ff4144..b348057c5 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -88,8 +88,9 @@ DEFAULT_CLIENT_INFO, BigtableAsyncClient, ) -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcAsyncIOTransport +from google.cloud.bigtable_v2.services.bigtable.transports import ( + BigtableGrpcAsyncIOTransport, +) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest if TYPE_CHECKING: @@ -105,7 +106,7 @@ def __init__( client_options: dict[str, Any] | "google.api_core.client_options.ClientOptions" | None = None, - **kwargs + **kwargs, ): """ Create a client instance for the Bigtable Data API @@ -163,9 +164,13 @@ def __init__( credentials=credentials, client_options=client_options, client_info=client_info, - transport=lambda *args, **kwargs: BigtableGrpcAsyncIOTransport(*args, **kwargs, channel=custom_channel) + transport=lambda *args, **kwargs: BigtableGrpcAsyncIOTransport( + *args, **kwargs, channel=custom_channel + ), + ) + self.transport = cast( + BigtableGrpcAsyncIOTransport, self._gapic_client.transport ) - self.transport = self._gapic_client.transport # keep track of active instances to for warmup on channel refresh self._active_instances: Set[_WarmedInstanceKey] = set() # keep track of table objects associated with each instance @@ -235,7 +240,7 @@ async def _ping_and_warm_instances( instance_list = ( [instance_key] if instance_key is not None else self._active_instances ) - ping_rpc = self.transport._grpc_channel.unary_unary( + ping_rpc = self.transport.grpc_channel.unary_unary( "/google.bigtable.v2.Bigtable/PingAndWarm", request_serializer=PingAndWarmRequest.serialize, ) @@ -294,8 +299,8 @@ async def _manage_channel( await asyncio.sleep(next_sleep) start_timestamp = time.time() # prepare new channel for use - old_channel = self.transport._grpc_channel - new_channel = self.transport_create_channel() + old_channel = self.transport.grpc_channel + new_channel = self.transport.create_channel() await self._ping_and_warm_instances() # cycle channel out of use, with long grace window before closure self.transport._grpc_channel = new_channel diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 32338a817..c6ad1995f 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -117,7 +117,6 @@ async def test_ctor_super_inits(self): # test gapic superclass init was called assert bigtable_client_init.call_count == 1 kwargs = bigtable_client_init.call_args[1] - assert kwargs["transport"] == transport_str assert kwargs["credentials"] == credentials assert kwargs["client_options"] == options_parsed # test mixin superclass init was called @@ -193,9 +192,7 @@ async def test__start_background_channel_refresh_task_exists(self): @pytest.mark.asyncio async def test__start_background_channel_refresh(self): # should create background tasks for each channel - client = self._make_one( - project="project-id", use_emulator=False - ) + client = self._make_one(project="project-id", use_emulator=False) ping_and_warm = AsyncMock() client._ping_and_warm_instances = ping_and_warm client._start_background_channel_refresh() @@ -203,7 +200,7 @@ async def test__start_background_channel_refresh(self): assert isinstance(client._channel_refresh_task, asyncio.Task) await asyncio.sleep(0.1) assert ping_and_warm.call_count == 1 - ping_and_warm.assert_any_call(client.transport._grpc_channel) + ping_and_warm.assert_any_call(client.transport.grpc_channel) await client.close() @pytest.mark.asyncio @@ -212,9 +209,7 @@ async def test__start_background_channel_refresh(self): ) async def test__start_background_channel_refresh_task_names(self): # if tasks exist, should do nothing - client = self._make_one( - project="project-id", use_emulator=False - ) + client = self._make_one(project="project-id", use_emulator=False) name = client._channel_refresh_task.get_name() assert "BigtableDataClientAsync channel refresh" in name await client.close()