From 52ec52adb539c5c9b246c0db4373f697fcd02313 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 18 Jan 2024 17:22:45 -0800
Subject: [PATCH] use the same client init function in unit tests

---
 tests/unit/data/_async/test_client.py | 142 +++++++++++---------------
 1 file changed, 59 insertions(+), 83 deletions(-)

diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py
index 907def6af..26660ad18 100644
--- a/tests/unit/data/_async/test_client.py
+++ b/tests/unit/data/_async/test_client.py
@@ -45,25 +45,30 @@
 )


+def _make_client(*args, use_emulator=True, **kwargs):
+    import os
+    from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
+    env_mask = {}
+    # by default, use emulator mode to avoid auth issues in CI
+    # emulator mode must be disabled by tests that check channel pooling/refresh background tasks
+    if use_emulator:
+        env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost"
+    else:
+        # set some default values
+        kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials())
+        kwargs["project"] = kwargs.get("project", "project-id")
+    with mock.patch.dict(os.environ, env_mask):
+        return BigtableDataClientAsync(*args, **kwargs)
+
+
 class TestBigtableDataClientAsync:
     def _get_target_class(self):
         from google.cloud.bigtable.data._async.client import BigtableDataClientAsync

         return BigtableDataClientAsync

-    def _make_one(self, *args, use_emulator=True, **kwargs):
-        import os
-        env_mask = {}
-        # by default, use emulator mode to avoid auth issues in CI
-        # emulator mode must be disabled by tests that check channel pooling/refresh background tasks
-        if use_emulator:
-            env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost"
-        else:
-            # set some default values
-            kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials())
-            kwargs["project"] = kwargs.get("project", "project-id")
-        with mock.patch.dict(os.environ, env_mask):
-            return self._get_target_class()(*args, **kwargs)
+    def _make_one(self, *args, **kwargs):
+        return _make_client(*args, **kwargs)

     @pytest.mark.asyncio
     async def test_ctor(self):
@@ -1321,11 +1326,6 @@ class TestReadRows:
     Tests for table.read_rows and related methods.
     """

-    def _make_client(self, *args, **kwargs):
-        from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
-
-        return BigtableDataClientAsync(*args, **kwargs)
-
     def _make_table(self, *args, **kwargs):
         from google.cloud.bigtable.data._async.client import TableAsync

@@ -1694,7 +1694,7 @@ async def test_read_rows_default_timeout_override(self):
     @pytest.mark.asyncio
     async def test_read_row(self):
         """Test reading a single row"""
-        async with self._make_client() as client:
+        async with _make_client() as client:
             table = client.get_table("instance", "table")
             row_key = b"test_1"
             with mock.patch.object(table, "read_rows") as read_rows:
@@ -1722,7 +1722,7 @@ async def test_read_row(self):
     @pytest.mark.asyncio
     async def test_read_row_w_filter(self):
         """Test reading a single row with an added filter"""
-        async with self._make_client() as client:
+        async with _make_client() as client:
             table = client.get_table("instance", "table")
             row_key = b"test_1"
             with mock.patch.object(table, "read_rows") as read_rows:
@@ -1755,7 +1755,7 @@ async def test_read_row_w_filter(self):
     @pytest.mark.asyncio
     async def test_read_row_no_response(self):
         """should return None if row does not exist"""
-        async with self._make_client() as client:
+        async with _make_client() as client:
             table = client.get_table("instance", "table")
             row_key = b"test_1"
             with mock.patch.object(table, "read_rows") as read_rows:
@@ -1790,7 +1790,7 @@ async def test_read_row_no_response(self):
     @pytest.mark.asyncio
     async def test_row_exists(self, return_value, expected_result):
         """Test checking for row existence"""
-        async with self._make_client() as client:
+        async with _make_client() as client:
             table = client.get_table("instance", "table")
             row_key = b"test_1"
             with mock.patch.object(table, "read_rows") as read_rows:
@@ -1825,14 +1825,10 @@ async def test_row_exists(self, return_value, expected_result):


 class TestReadRowsSharded:
-    def _make_client(self, *args, **kwargs):
-        from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
-
-        return BigtableDataClientAsync(*args, **kwargs)

     @pytest.mark.asyncio
     async def test_read_rows_sharded_empty_query(self):
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with pytest.raises(ValueError) as exc:
                     await table.read_rows_sharded([])
@@ -1843,7 +1839,7 @@ async def test_read_rows_sharded_multiple_queries(self):
         """
         Test with multiple queries. Should return results from both
         """
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     table.client._gapic_client, "read_rows"
@@ -1869,7 +1865,7 @@ async def test_read_rows_sharded_multiple_queries_calls(self, n_queries):
         """
         Each query should trigger a separate read_rows call
         """
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(table, "read_rows") as read_rows:
                     query_list = [ReadRowsQuery() for _ in range(n_queries)]
@@ -1884,7 +1880,7 @@ async def test_read_rows_sharded_errors(self):
         from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
         from google.cloud.bigtable.data.exceptions import FailedQueryShardError

-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(table, "read_rows") as read_rows:
                     read_rows.side_effect = RuntimeError("mock error")
@@ -1915,7 +1911,7 @@ async def mock_call(*args, **kwargs):
             await asyncio.sleep(0.1)
             return [mock.Mock()]

-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(table, "read_rows") as read_rows:
                     read_rows.side_effect = mock_call
@@ -1988,10 +1984,6 @@ async def test_read_rows_sharded_batching(self):


 class TestSampleRowKeys:
-    def _make_client(self, *args, **kwargs):
-        from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
-
-        return BigtableDataClientAsync(*args, **kwargs)

     async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]):
         from google.cloud.bigtable_v2.types import SampleRowKeysResponse
@@ -2009,7 +2001,7 @@ async def test_sample_row_keys(self):
             (b"test_2", 100),
             (b"test_3", 200),
         ]
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     table.client._gapic_client, "sample_row_keys", AsyncMock()
@@ -2029,7 +2021,7 @@ async def test_sample_row_keys_bad_timeout(self):
         """
         should raise error if timeout is negative
         """
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with pytest.raises(ValueError) as e:
                     await table.sample_row_keys(operation_timeout=-1)
@@ -2042,7 +2034,7 @@ async def test_sample_row_keys_bad_timeout(self):
     async def test_sample_row_keys_default_timeout(self):
         """Should fallback to using table default operation_timeout"""
         expected_timeout = 99
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table(
                 "i",
                 "t",
@@ -2068,7 +2060,7 @@ async def test_sample_row_keys_gapic_params(self):
         expected_profile = "test1"
         instance = "instance_name"
         table_id = "my_table"
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table(
                 instance, table_id, app_profile_id=expected_profile
             ) as table:
@@ -2101,7 +2093,7 @@ async def test_sample_row_keys_retryable_errors(self, retryable_exception):
         from google.api_core.exceptions import DeadlineExceeded
         from google.cloud.bigtable.data.exceptions import RetryExceptionGroup

-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     table.client._gapic_client, "sample_row_keys", AsyncMock()
@@ -2130,7 +2122,7 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio
         """
         non-retryable errors should cause a raise
         """
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     table.client._gapic_client, "sample_row_keys", AsyncMock()
@@ -2141,10 +2133,6 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio


 class TestMutateRow:
-    def _make_client(self, *args, **kwargs):
-        from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
-
-        return BigtableDataClientAsync(*args, **kwargs)

     @pytest.mark.asyncio
     @pytest.mark.parametrize(
@@ -2167,7 +2155,7 @@ def _make_client(self, *args, **kwargs):
     async def test_mutate_row(self, mutation_arg):
         """Test mutations with no errors"""
         expected_attempt_timeout = 19
-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_row"
@@ -2207,7 +2195,7 @@ async def test_mutate_row_retryable_errors(self, retryable_exception):
         from google.api_core.exceptions import DeadlineExceeded
         from google.cloud.bigtable.data.exceptions import RetryExceptionGroup

-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_row"
@@ -2237,7 +2225,7 @@ async def test_mutate_row_non_idempotent_retryable_errors(
         """
         Non-idempotent mutations should not be retried
         """
-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_row"
@@ -2265,7 +2253,7 @@ async def test_mutate_row_non_idempotent_retryable_errors(
     )
     @pytest.mark.asyncio
     async def test_mutate_row_non_retryable_errors(self, non_retryable_exception):
-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_row"
@@ -2288,7 +2276,7 @@ async def test_mutate_row_non_retryable_errors(self, non_retryable_exception):
     async def test_mutate_row_metadata(self, include_app_profile):
         """request should attach metadata headers"""
         profile = "profile" if include_app_profile else None
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("i", "t", app_profile_id=profile) as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_row", AsyncMock()
@@ -2310,7 +2298,7 @@ async def test_mutate_row_metadata(self, include_app_profile):
     @pytest.mark.parametrize("mutations", [[], None])
     @pytest.mark.asyncio
     async def test_mutate_row_no_mutations(self, mutations):
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with pytest.raises(ValueError) as e:
                     await table.mutate_row("key", mutations=mutations)
@@ -2318,10 +2306,6 @@ async def test_mutate_row_no_mutations(self, mutations):


 class TestBulkMutateRows:
-    def _make_client(self, *args, **kwargs):
-        from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
-
-        return BigtableDataClientAsync(*args, **kwargs)

     async def _mock_response(self, response_list):
         from google.cloud.bigtable_v2.types import MutateRowsResponse
@@ -2371,7 +2355,7 @@ async def generator():
     async def test_bulk_mutate_rows(self, mutation_arg):
         """Test mutations with no errors"""
         expected_attempt_timeout = 19
-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_rows"
@@ -2395,7 +2379,7 @@ async def test_bulk_mutate_rows(self, mutation_arg):
     @pytest.mark.asyncio
     async def test_bulk_mutate_rows_multiple_entries(self):
         """Test mutations with no errors"""
-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_rows"
@@ -2436,7 +2420,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_retryable(
             MutationsExceptionGroup,
         )

-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_rows"
@@ -2482,7 +2466,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable(
             MutationsExceptionGroup,
         )

-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_rows"
@@ -2522,7 +2506,7 @@ async def test_bulk_mutate_idempotent_retryable_request_errors(
             MutationsExceptionGroup,
         )

-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_rows"
@@ -2560,7 +2544,7 @@ async def test_bulk_mutate_rows_non_idempotent_retryable_errors(
             MutationsExceptionGroup,
         )

-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_rows"
@@ -2602,7 +2586,7 @@ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_excepti
             MutationsExceptionGroup,
         )

-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_rows"
@@ -2638,7 +2622,7 @@ async def test_bulk_mutate_error_index(self):
             MutationsExceptionGroup,
         )

-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "mutate_rows"
@@ -2680,7 +2664,7 @@ async def test_bulk_mutate_error_recovery(self):
         """
         from google.api_core.exceptions import DeadlineExceeded

-        async with self._make_client(project="project") as client:
+        async with _make_client(project="project") as client:
             table = client.get_table("instance", "table")
             with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic:
                 # fail with a retryable error, then a non-retryable one
@@ -2699,10 +2683,6 @@ async def test_bulk_mutate_error_recovery(self):


 class TestCheckAndMutateRow:
-    def _make_client(self, *args, **kwargs):
-        from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
-
-        return BigtableDataClientAsync(*args, **kwargs)

     @pytest.mark.parametrize("gapic_result", [True, False])
     @pytest.mark.asyncio
@@ -2710,7 +2690,7 @@ async def test_check_and_mutate(self, gapic_result):
         from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse

         app_profile = "app_profile_id"
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table(
                 "instance", "table", app_profile_id=app_profile
             ) as table:
@@ -2750,7 +2730,7 @@ async def test_check_and_mutate(self, gapic_result):
     @pytest.mark.asyncio
     async def test_check_and_mutate_bad_timeout(self):
         """Should raise error if operation_timeout < 0"""
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with pytest.raises(ValueError) as e:
                     await table.check_and_mutate_row(
@@ -2768,7 +2748,7 @@ async def test_check_and_mutate_single_mutations(self):
         from google.cloud.bigtable.data.mutations import SetCell
         from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse

-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "check_and_mutate_row"
@@ -2796,7 +2776,7 @@ async def test_check_and_mutate_predicate_object(self):
         mock_predicate = mock.Mock()
         predicate_pb = {"predicate": "dict"}
         mock_predicate._to_pb.return_value = predicate_pb
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "check_and_mutate_row"
@@ -2824,7 +2804,7 @@ async def test_check_and_mutate_mutations_parsing(self):
         for idx, mutation in enumerate(mutations):
             mutation._to_pb.return_value = f"fake {idx}"
         mutations.append(DeleteAllFromRow())
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "check_and_mutate_row"
@@ -2852,10 +2832,6 @@ async def test_check_and_mutate_mutations_parsing(self):


 class TestReadModifyWriteRow:
-    def _make_client(self, *args, **kwargs):
-        from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
-
-        return BigtableDataClientAsync(*args, **kwargs)

     @pytest.mark.parametrize(
         "call_rules,expected_rules",
@@ -2883,7 +2859,7 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules
         """
         Test that the gapic call is called with given rules
         """
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with mock.patch.object(
                     client._gapic_client, "read_modify_write_row"
@@ -2897,7 +2873,7 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules
     @pytest.mark.parametrize("rules", [[], None])
     @pytest.mark.asyncio
    async def test_read_modify_write_no_rules(self, rules):
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table") as table:
                 with pytest.raises(ValueError) as e:
                     await table.read_modify_write_row("key", rules=rules)
@@ -2909,7 +2885,7 @@ async def test_read_modify_write_call_defaults(self):
         table_id = "table1"
         project = "project1"
         row_key = "row_key1"
-        async with self._make_client(project=project) as client:
+        async with _make_client(project=project) as client:
             async with client.get_table(instance, table_id) as table:
                 with mock.patch.object(
                     client._gapic_client, "read_modify_write_row"
@@ -2930,7 +2906,7 @@ async def test_read_modify_write_call_overrides(self):
         row_key = b"row_key1"
         expected_timeout = 12345
         profile_id = "profile1"
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table(
                 "instance", "table_id", app_profile_id=profile_id
             ) as table:
@@ -2951,7 +2927,7 @@ async def test_read_modify_write_call_overrides(self):
     @pytest.mark.asyncio
     async def test_read_modify_write_string_key(self):
         row_key = "string_row_key1"
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table_id") as table:
                 with mock.patch.object(
                     client._gapic_client, "read_modify_write_row"
@@ -2971,7 +2947,7 @@ async def test_read_modify_write_row_building(self):
         from google.cloud.bigtable_v2.types import Row as RowPB

         mock_response = ReadModifyWriteRowResponse(row=RowPB())
-        async with self._make_client() as client:
+        async with _make_client() as client:
             async with client.get_table("instance", "table_id") as table:
                 with mock.patch.object(
                     client._gapic_client, "read_modify_write_row"
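For context, a minimal sketch of the pattern this patch converges on: one module-level client factory shared by every test class, which masks BIGTABLE_EMULATOR_HOST by default so tests need no real credentials or channel-refresh background tasks. The helper body mirrors the _make_client added above; the standalone test function, its name, and the instance/table strings are illustrative assumptions, not part of the patch.

# Illustrative sketch only -- not part of the patch. It mirrors the module-level
# _make_client helper introduced above and shows how any test can consume it.
import os
from unittest import mock

import pytest
from google.auth.credentials import AnonymousCredentials

from google.cloud.bigtable.data._async.client import BigtableDataClientAsync


def _make_client(*args, use_emulator=True, **kwargs):
    env_mask = {}
    # emulator mode avoids real auth and channel-refresh background tasks in CI
    if use_emulator:
        env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost"
    else:
        # anonymous credentials and a placeholder project keep offline runs working
        kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials())
        kwargs["project"] = kwargs.get("project", "project-id")
    with mock.patch.dict(os.environ, env_mask):
        return BigtableDataClientAsync(*args, **kwargs)


@pytest.mark.asyncio
async def test_get_table_with_shared_factory():
    # hypothetical test: the instance/table ids below are placeholders
    async with _make_client() as client:
        async with client.get_table("instance", "table") as table:
            assert table.table_id == "table"  # table_id attribute assumed on TableAsync

Defining the factory once at module scope, rather than as a method on each test class, is what lets TestBigtableDataClientAsync._make_one and the per-RPC test classes share identical construction logic.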