diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 689202267..a81178ea3 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -365,8 +365,14 @@ async def _manage_channel( await self._ping_and_warm_instances(channel=new_channel) # cycle channel out of use, with long grace window before closure self.transport._grpc_channel = new_channel - await old_channel.close(grace_period) - # subtract the time spent waiting for the channel to be replaced + # give old_channel a chance to complete existing rpcs + if CrossSync.is_async: + await old_channel.close(grace_period) + else: + if grace_period: + self._is_closed.wait(grace_period) + old_channel.close() + # subtract thed time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 2e18d83ab..1208d55a3 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -213,15 +213,15 @@ async def test__start_background_channel_refresh_task_exists(self): @CrossSync.pytest async def test__start_background_channel_refresh(self): # should create background tasks for each channel - client = self._make_client(project="project-id", use_emulator=False) - ping_and_warm = CrossSync.Mock() - client._ping_and_warm_instances = ping_and_warm - client._start_background_channel_refresh() - assert client._channel_refresh_task is not None - assert isinstance(client._channel_refresh_task, asyncio.Task) - await asyncio.sleep(0.1) - assert ping_and_warm.call_count == 1 - await client.close() + client = self._make_client(project="project-id") + with mock.patch.object(client, "_ping_and_warm_instances", CrossSync.Mock()) as ping_and_warm: + client._emulator_host = None + client._start_background_channel_refresh() + assert client._channel_refresh_task is not None + assert isinstance(client._channel_refresh_task, CrossSync.Task) + await CrossSync.sleep(0.1) + assert ping_and_warm.call_count == 1 + await client.close() @CrossSync.drop @CrossSync.pytest @@ -427,12 +427,7 @@ async def test__manage_channel_sleeps( uniform.side_effect = lambda min_, max_: min_ with mock.patch.object(time, "time") as time_mock: time_mock.return_value = 0 - sleep_tuple = ( - (asyncio, "sleep") - if CrossSync.is_async - else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = [None for i in range(num_cycles - 1)] + [ asyncio.CancelledError ] @@ -441,19 +436,14 @@ async def test__manage_channel_sleeps( try: if refresh_interval is not None: await client._manage_channel( - refresh_interval, refresh_interval + refresh_interval, refresh_interval, grace_period=0 ) else: - await client._manage_channel() + await client._manage_channel(grace_period=0) except asyncio.CancelledError: pass assert sleep.call_count == num_cycles - if CrossSync.is_async: - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) - else: - total_sleep = sum( - [call[1]["timeout"] for call in sleep.call_args_list] - ) + total_sleep = sum([call[0][1] for call in sleep.call_args_list]) assert ( abs(total_sleep - expected_sleep) < 0.1 ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" @@ -464,10 +454,7 @@ async def test__manage_channel_random(self): import random import threading - sleep_tuple = ( - (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(CrossSync, "event_wait") as sleep: with mock.patch.object(random, "uniform") as uniform: uniform.return_value = 0 try: @@ -483,7 +470,7 @@ async def test__manage_channel_random(self): uniform.side_effect = lambda min_, max_: min_ sleep.side_effect = [None, asyncio.CancelledError] try: - await client._manage_channel(min_val, max_val) + await client._manage_channel(min_val, max_val, grace_period=0) except asyncio.CancelledError: pass assert uniform.call_count == 2 @@ -496,28 +483,27 @@ async def test__manage_channel_random(self): @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) async def test__manage_channel_refresh(self, num_cycles): # make sure that channels are properly refreshed - expected_grace = 9 expected_refresh = 0.5 grpc_lib = grpc.aio if CrossSync.is_async else grpc new_channel = grpc_lib.insecure_channel("localhost:8080") with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [ - asyncio.CancelledError + RuntimeError ] with mock.patch.object( CrossSync.grpc_helpers, "create_channel" ) as create_channel: create_channel.return_value = new_channel - client = self._make_client(project="project-id", use_emulator=False) + client = self._make_client(project="project-id") create_channel.reset_mock() try: await client._manage_channel( refresh_interval_min=expected_refresh, refresh_interval_max=expected_refresh, - grace_period=expected_grace, + grace_period=0, ) - except asyncio.CancelledError: + except RuntimeError: pass assert sleep.call_count == num_cycles + 1 assert create_channel.call_count == num_cycles @@ -935,9 +921,9 @@ async def test_close(self): with mock.patch.object(client.transport, "close", CrossSync.Mock()) as close_mock: await client.close() close_mock.assert_called_once() - close_mock.assert_awaited() + if CrossSync.is_async: + close_mock.assert_awaited() assert task.done() - assert task.cancelled() assert client._channel_refresh_task is None @CrossSync.pytest @@ -954,11 +940,13 @@ async def test_close_with_timeout(self): @CrossSync.pytest async def test_context_manager(self): + from functools import partial # context manager should close the client cleanly close_mock = CrossSync.Mock() true_close = None async with self._make_client(project="project-id", use_emulator=False) as client: - true_close = client.close() + # grab reference to close coro for async test + true_close = partial(client.close) client.close = close_mock assert not client._channel_refresh_task.done() assert client.project == "project-id" @@ -968,7 +956,7 @@ async def test_context_manager(self): if CrossSync.is_async: close_mock.assert_awaited() # actually close the client - await true_close + await true_close() @CrossSync.drop def test_client_ctor_sync(self): @@ -1275,8 +1263,12 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ transport_mock = mock.MagicMock() rpc_mock = CrossSync.Mock() transport_mock._wrapped_methods.__getitem__.return_value = rpc_mock - client._gapic_client._client._transport = transport_mock - client._gapic_client._client._is_universe_domain_valid = True + gapic_client = client._gapic_client + if CrossSync.is_async: + # inner BigtableClient is held as ._client for BigtableAsyncClient + gapic_client = gapic_client._client + gapic_client._transport = transport_mock + gapic_client._is_universe_domain_valid = True table = self._get_target_class()(client, "instance-id", "table-id", profile) try: test_fn = table.__getattribute__(fn_name)