Skip to content

Commit

Permalink
fixed grace_period for sync client
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Nov 11, 2024
1 parent 008e724 commit 7fb2134
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 41 deletions.
10 changes: 8 additions & 2 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,14 @@ async def _manage_channel(
await self._ping_and_warm_instances(channel=new_channel)
# cycle channel out of use, with long grace window before closure
self.transport._grpc_channel = new_channel
await old_channel.close(grace_period)
# subtract the time spent waiting for the channel to be replaced
# give old_channel a chance to complete existing rpcs
if CrossSync.is_async:
await old_channel.close(grace_period)
else:
if grace_period:
self._is_closed.wait(grace_period)
old_channel.close()
# subtract thed time spent waiting for the channel to be replaced
next_refresh = random.uniform(refresh_interval_min, refresh_interval_max)
next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0)

Expand Down
70 changes: 31 additions & 39 deletions tests/unit/data/_async/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,15 @@ async def test__start_background_channel_refresh_task_exists(self):
@CrossSync.pytest
async def test__start_background_channel_refresh(self):
# should create background tasks for each channel
client = self._make_client(project="project-id", use_emulator=False)
ping_and_warm = CrossSync.Mock()
client._ping_and_warm_instances = ping_and_warm
client._start_background_channel_refresh()
assert client._channel_refresh_task is not None
assert isinstance(client._channel_refresh_task, asyncio.Task)
await asyncio.sleep(0.1)
assert ping_and_warm.call_count == 1
await client.close()
client = self._make_client(project="project-id")
with mock.patch.object(client, "_ping_and_warm_instances", CrossSync.Mock()) as ping_and_warm:
client._emulator_host = None
client._start_background_channel_refresh()
assert client._channel_refresh_task is not None
assert isinstance(client._channel_refresh_task, CrossSync.Task)
await CrossSync.sleep(0.1)
assert ping_and_warm.call_count == 1
await client.close()

@CrossSync.drop
@CrossSync.pytest
Expand Down Expand Up @@ -427,12 +427,7 @@ async def test__manage_channel_sleeps(
uniform.side_effect = lambda min_, max_: min_
with mock.patch.object(time, "time") as time_mock:
time_mock.return_value = 0
sleep_tuple = (
(asyncio, "sleep")
if CrossSync.is_async
else (threading.Event, "wait")
)
with mock.patch.object(*sleep_tuple) as sleep:
with mock.patch.object(CrossSync, "event_wait") as sleep:
sleep.side_effect = [None for i in range(num_cycles - 1)] + [
asyncio.CancelledError
]
Expand All @@ -441,19 +436,14 @@ async def test__manage_channel_sleeps(
try:
if refresh_interval is not None:
await client._manage_channel(
refresh_interval, refresh_interval
refresh_interval, refresh_interval, grace_period=0
)
else:
await client._manage_channel()
await client._manage_channel(grace_period=0)
except asyncio.CancelledError:
pass
assert sleep.call_count == num_cycles
if CrossSync.is_async:
total_sleep = sum([call[0][0] for call in sleep.call_args_list])
else:
total_sleep = sum(
[call[1]["timeout"] for call in sleep.call_args_list]
)
total_sleep = sum([call[0][1] for call in sleep.call_args_list])
assert (
abs(total_sleep - expected_sleep) < 0.1
), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}"
Expand All @@ -464,10 +454,7 @@ async def test__manage_channel_random(self):
import random
import threading

sleep_tuple = (
(asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait")
)
with mock.patch.object(*sleep_tuple) as sleep:
with mock.patch.object(CrossSync, "event_wait") as sleep:
with mock.patch.object(random, "uniform") as uniform:
uniform.return_value = 0
try:
Expand All @@ -483,7 +470,7 @@ async def test__manage_channel_random(self):
uniform.side_effect = lambda min_, max_: min_
sleep.side_effect = [None, asyncio.CancelledError]
try:
await client._manage_channel(min_val, max_val)
await client._manage_channel(min_val, max_val, grace_period=0)
except asyncio.CancelledError:
pass
assert uniform.call_count == 2
Expand All @@ -496,28 +483,27 @@ async def test__manage_channel_random(self):
@pytest.mark.parametrize("num_cycles", [0, 1, 10, 100])
async def test__manage_channel_refresh(self, num_cycles):
# make sure that channels are properly refreshed
expected_grace = 9
expected_refresh = 0.5
grpc_lib = grpc.aio if CrossSync.is_async else grpc
new_channel = grpc_lib.insecure_channel("localhost:8080")

with mock.patch.object(CrossSync, "event_wait") as sleep:
sleep.side_effect = [None for i in range(num_cycles)] + [
asyncio.CancelledError
RuntimeError
]
with mock.patch.object(
CrossSync.grpc_helpers, "create_channel"
) as create_channel:
create_channel.return_value = new_channel
client = self._make_client(project="project-id", use_emulator=False)
client = self._make_client(project="project-id")
create_channel.reset_mock()
try:
await client._manage_channel(
refresh_interval_min=expected_refresh,
refresh_interval_max=expected_refresh,
grace_period=expected_grace,
grace_period=0,
)
except asyncio.CancelledError:
except RuntimeError:
pass
assert sleep.call_count == num_cycles + 1
assert create_channel.call_count == num_cycles
Expand Down Expand Up @@ -935,9 +921,9 @@ async def test_close(self):
with mock.patch.object(client.transport, "close", CrossSync.Mock()) as close_mock:
await client.close()
close_mock.assert_called_once()
close_mock.assert_awaited()
if CrossSync.is_async:
close_mock.assert_awaited()
assert task.done()
assert task.cancelled()
assert client._channel_refresh_task is None

@CrossSync.pytest
Expand All @@ -954,11 +940,13 @@ async def test_close_with_timeout(self):

@CrossSync.pytest
async def test_context_manager(self):
from functools import partial
# context manager should close the client cleanly
close_mock = CrossSync.Mock()
true_close = None
async with self._make_client(project="project-id", use_emulator=False) as client:
true_close = client.close()
# grab reference to close coro for async test
true_close = partial(client.close)
client.close = close_mock
assert not client._channel_refresh_task.done()
assert client.project == "project-id"
Expand All @@ -968,7 +956,7 @@ async def test_context_manager(self):
if CrossSync.is_async:
close_mock.assert_awaited()
# actually close the client
await true_close
await true_close()

@CrossSync.drop
def test_client_ctor_sync(self):
Expand Down Expand Up @@ -1275,8 +1263,12 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_
transport_mock = mock.MagicMock()
rpc_mock = CrossSync.Mock()
transport_mock._wrapped_methods.__getitem__.return_value = rpc_mock
client._gapic_client._client._transport = transport_mock
client._gapic_client._client._is_universe_domain_valid = True
gapic_client = client._gapic_client
if CrossSync.is_async:
# inner BigtableClient is held as ._client for BigtableAsyncClient
gapic_client = gapic_client._client
gapic_client._transport = transport_mock
gapic_client._is_universe_domain_valid = True
table = self._get_target_class()(client, "instance-id", "table-id", profile)
try:
test_fn = table.__getattribute__(fn_name)
Expand Down

0 comments on commit 7fb2134

Please sign in to comment.