
Commit

fixed tests
daniel-sanche committed Dec 14, 2023
1 parent 9638d2a commit a46183e
Showing 7 changed files with 302 additions and 300 deletions.
4 changes: 0 additions & 4 deletions google/cloud/bigtable/data/_async/_read_rows.py
@@ -37,14 +37,10 @@
from google.cloud.bigtable.data._helpers import _make_metadata
from google.cloud.bigtable.data._helpers import backoff_generator

from google.api_core.retry_streaming_async import retry_target_stream
from google.api_core.retry import RetryFailureReason
from google.api_core import exceptions as core_exceptions
from google.api_core.grpc_helpers_async import GrpcAsyncStream
from google.cloud.bigtable.data._helpers import _retry_exception_factory

from google.api_core import retry as retries
from google.api_core.retry import exponential_sleep_generator

if TYPE_CHECKING:
from google.cloud.bigtable.data._async.client import TableAsync
4 changes: 0 additions & 4 deletions google/cloud/bigtable/data/_async/client.py
@@ -935,10 +935,6 @@ async def execute_rpc():
metric_wrapped = operation.wrap_attempt_fn(
execute_rpc, extract_call_metadata=False
)
retry_wrapped = retry(metric_wrapped)
deadline_wrapped = _convert_retry_deadline(
retry_wrapped, operation_timeout, transient_errors, is_async=True
)
return await retries.retry_target_async(
metric_wrapped,
predicate,
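Note: with the retry/deadline wrappers removed above, the metric-wrapped attempt is handed straight to retry_target_async. A rough sketch of that pattern, with illustrative names and values rather than the client's actual ones:

from google.api_core import exceptions as core_exceptions
from google.api_core import retry as retries


async def metric_wrapped():
    """Stand-in for one metric-wrapped RPC attempt (illustrative only)."""
    ...


# retry only on Aborted, with exponential backoff between attempts
predicate = retries.if_exception_type(core_exceptions.Aborted)
sleep_generator = retries.exponential_sleep_generator(initial=0.01, maximum=1.0)

# inside a coroutine:
# result = await retries.retry_target_async(
#     metric_wrapped, predicate, sleep_generator, timeout=60.0
# )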
4 changes: 2 additions & 2 deletions google/cloud/bigtable/data/_metrics/data_model.py
@@ -361,9 +361,9 @@ def _exc_to_status(exc: Exception) -> StatusCode:
"""
if isinstance(exc, bt_exceptions._BigtableExceptionGroup):
exc = exc.exceptions[-1]
if hasattr(exc, "grpc_status_code"):
if hasattr(exc, "grpc_status_code") and exc.grpc_status_code is not None:
return exc.grpc_status_code
if exc.__cause__ and hasattr(exc.__cause__, "grpc_status_code"):
if exc.__cause__ and hasattr(exc.__cause__, "grpc_status_code") and exc.__cause__.grpc_status_code is not None:
return exc.__cause__.grpc_status_code
return StatusCode.UNKNOWN

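The added "is not None" guards matter when an exception defines grpc_status_code but leaves it unset: the lookup should then fall through to the exception's cause, or UNKNOWN, rather than returning None. A minimal standalone sketch of that behavior (the helper and exception class below are illustrative, not the library's code):

from grpc import StatusCode


def exc_to_status(exc: Exception) -> StatusCode:
    # mirrors the updated checks shown in the diff above
    if hasattr(exc, "grpc_status_code") and exc.grpc_status_code is not None:
        return exc.grpc_status_code
    if exc.__cause__ and hasattr(exc.__cause__, "grpc_status_code") and exc.__cause__.grpc_status_code is not None:
        return exc.__cause__.grpc_status_code
    return StatusCode.UNKNOWN


class FakeRpcError(Exception):
    grpc_status_code = None  # attribute present, but unset


wrapped = FakeRpcError()
wrapped.__cause__ = FakeRpcError()
wrapped.__cause__.grpc_status_code = StatusCode.ABORTED

# a hasattr-only check would return None here; the new guards
# fall through to the cause's real status instead
assert exc_to_status(wrapped) == StatusCode.ABORTED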
359 changes: 183 additions & 176 deletions tests/unit/data/_async/test_client.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/unit/data/_metrics/test_data_model.py
@@ -858,7 +858,7 @@ async def test_wrap_attempt_fn_with_retry(self):
wrap_attampt_fn is meant to be used with retry object. Test using them together
"""
from grpc import StatusCode
from google.api_core.retry_async import AsyncRetry
from google.api_core.retry import AsyncRetry
from google.api_core.exceptions import RetryError

metric = self._make_one(object())
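For reference, recent google-api-core releases export AsyncRetry from google.api_core.retry, which is what the test now imports. A rough usage sketch pairing it with a wrapped attempt function (hypothetical names, not the test body):

from google.api_core.exceptions import Aborted
from google.api_core.retry import AsyncRetry


async def attempt():
    """Stand-in for a metric-wrapped attempt function (illustrative only)."""
    ...


retry = AsyncRetry(predicate=lambda exc: isinstance(exc, Aborted), timeout=60.0)
# inside a coroutine:
# wrapped = retry(attempt)
# result = await wrapped()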
223 changes: 112 additions & 111 deletions tests/unit/data/_metrics/test_rpcs_instrumented.py
@@ -30,36 +30,37 @@

from .._async.test_client import mock_grpc_call

ALL_RPC_PARAMS = (
"fn_name,fn_args,gapic_fn,is_unary,expected_type",
[
("read_rows_stream", (ReadRowsQuery(),), "read_rows", False, OperationType.READ_ROWS),
("read_rows", (ReadRowsQuery(),), "read_rows", False, OperationType.READ_ROWS),
("read_row", (b"row_key",), "read_rows", False, OperationType.READ_ROWS),
("read_rows_sharded", ([ReadRowsQuery()],), "read_rows", False, OperationType.READ_ROWS),
("row_exists", (b"row_key",), "read_rows", False, OperationType.READ_ROWS),
("sample_row_keys", (), "sample_row_keys", False, OperationType.SAMPLE_ROW_KEYS),
("mutate_row", (b"row_key", [mutations.DeleteAllFromRow()]), "mutate_row", False, OperationType.MUTATE_ROW),
(
"bulk_mutate_rows",
([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],),
"mutate_rows",
False,
OperationType.BULK_MUTATE_ROWS
),
("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row", True, OperationType.CHECK_AND_MUTATE),
(
"read_modify_write_row",
(b"row_key", mock.Mock()),
"read_modify_write_row",
True,
OperationType.READ_MODIFY_WRITE
),
]
)

RPC_ARGS = "fn_name,fn_args,gapic_fn,is_unary,expected_type"
RETRYABLE_RPCS = [
("read_rows_stream", (ReadRowsQuery(),), "read_rows", False, OperationType.READ_ROWS),
("read_rows", (ReadRowsQuery(),), "read_rows", False, OperationType.READ_ROWS),
("read_row", (b"row_key",), "read_rows", False, OperationType.READ_ROWS),
("read_rows_sharded", ([ReadRowsQuery()],), "read_rows", False, OperationType.READ_ROWS),
("row_exists", (b"row_key",), "read_rows", False, OperationType.READ_ROWS),
("sample_row_keys", (), "sample_row_keys", False, OperationType.SAMPLE_ROW_KEYS),
("mutate_row", (b"row_key", [mutations.DeleteAllFromRow()]), "mutate_row", False, OperationType.MUTATE_ROW),
(
"bulk_mutate_rows",
([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],),
"mutate_rows",
False,
OperationType.BULK_MUTATE_ROWS
),
]
ALL_RPCS = RETRYABLE_RPCS + [
("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row", True, OperationType.CHECK_AND_MUTATE),
(
"read_modify_write_row",
(b"row_key", mock.Mock()),
"read_modify_write_row",
True,
OperationType.READ_MODIFY_WRITE
),
]

@pytest.mark.parametrize(*ALL_RPC_PARAMS)

@pytest.mark.parametrize(RPC_ARGS, ALL_RPCS)
@pytest.mark.asyncio
async def test_rpc_instrumented(fn_name, fn_args, gapic_fn, is_unary, expected_type):
"""check that all requests attach proper metadata headers"""
@@ -81,49 +81,49 @@ async def test_rpc_instrumented(fn_name, fn_args, gapic_fn, is_unary, expected_t
grpc_call = mock_grpc_call(unary_response=unary_response, initial_metadata=initial_metadata, trailing_metadata=trailing_metadata)
gapic_mock.return_value = grpc_call
async with BigtableDataClientAsync() as client:
async with TableAsync(client, "instance-id", "table-id") as table:
# customize metrics handlers
mock_metric_handler = mock.Mock()
table._metrics.handlers = [mock_metric_handler]
test_fn = table.__getattribute__(fn_name)
maybe_stream = await test_fn(*fn_args)
# iterate stream if it exists
try:
[i async for i in maybe_stream]
except TypeError:
pass
# check for recorded metrics values
assert mock_metric_handler.on_operation_complete.call_count == 1
found_operation = mock_metric_handler.on_operation_complete.call_args[0][0]
# make sure expected fields were set properly
assert found_operation.op_type == expected_type
now = datetime.datetime.now(datetime.timezone.utc)
assert found_operation.start_time - now < datetime.timedelta(seconds=1)
assert found_operation.duration < 0.1
assert found_operation.duration > 0
assert found_operation.final_status == StatusCode.OK
assert found_operation.cluster_id == cluster_data
assert found_operation.zone == zone_data
# is_streaming should only be true for read_rows, read_rows_stream, and read_rows_sharded
assert found_operation.is_streaming == ("read_rows" in fn_name)
# check attempts
assert len(found_operation.completed_attempts) == 1
found_attempt = found_operation.completed_attempts[0]
assert found_attempt.end_status == StatusCode.OK
assert found_attempt.start_time - now < datetime.timedelta(seconds=1)
assert found_attempt.duration < 0.1
assert found_attempt.duration > 0
assert found_attempt.start_time >= found_operation.start_time
assert found_attempt.duration <= found_operation.duration
assert found_attempt.gfe_latency == expected_gfe_latency
# first response latency not populated, because no real read_rows chunks processed
assert found_attempt.first_response_latency is None
# no application blocking time or backoff time expected
assert found_attempt.application_blocking_time == 0
assert found_attempt.backoff_before_attempt == 0
table = TableAsync(client, "instance-id", "table-id")
# customize metrics handlers
mock_metric_handler = mock.Mock()
table._metrics.handlers = [mock_metric_handler]
test_fn = table.__getattribute__(fn_name)
maybe_stream = await test_fn(*fn_args)
# iterate stream if it exists
try:
[i async for i in maybe_stream]
except TypeError:
pass
# check for recorded metrics values
assert mock_metric_handler.on_operation_complete.call_count == 1
found_operation = mock_metric_handler.on_operation_complete.call_args[0][0]
# make sure expected fields were set properly
assert found_operation.op_type == expected_type
now = datetime.datetime.now(datetime.timezone.utc)
assert found_operation.start_time - now < datetime.timedelta(seconds=1)
assert found_operation.duration < 0.1
assert found_operation.duration > 0
assert found_operation.final_status == StatusCode.OK
assert found_operation.cluster_id == cluster_data
assert found_operation.zone == zone_data
# is_streaming should only be true for read_rows, read_rows_stream, and read_rows_sharded
assert found_operation.is_streaming == ("read_rows" in fn_name)
# check attempts
assert len(found_operation.completed_attempts) == 1
found_attempt = found_operation.completed_attempts[0]
assert found_attempt.end_status == StatusCode.OK
assert found_attempt.start_time - now < datetime.timedelta(seconds=1)
assert found_attempt.duration < 0.1
assert found_attempt.duration > 0
assert found_attempt.start_time >= found_operation.start_time
assert found_attempt.duration <= found_operation.duration
assert found_attempt.gfe_latency == expected_gfe_latency
# first response latency not populated, because no real read_rows chunks processed
assert found_attempt.first_response_latency is None
# no application blocking time or backoff time expected
assert found_attempt.application_blocking_time == 0
assert found_attempt.backoff_before_attempt == 0


@pytest.mark.parametrize(*ALL_RPC_PARAMS)
@pytest.mark.parametrize(RPC_ARGS, RETRYABLE_RPCS)
@pytest.mark.asyncio
async def test_rpc_instrumented_multiple_attempts(fn_name, fn_args, gapic_fn, is_unary, expected_type):
"""check that all requests attach proper metadata headers, with a retry"""
@@ -140,45 +141,45 @@ async def test_rpc_instrumented_multiple_attempts(fn_name, fn_args, gapic_fn, is
grpc_call = mock_grpc_call(unary_response=unary_response)
gapic_mock.side_effect = [Aborted("first attempt failed"), grpc_call]
async with BigtableDataClientAsync() as client:
async with TableAsync(client, "instance-id", "table-id") as table:
# customize metrics handlers
mock_metric_handler = mock.Mock()
table._metrics.handlers = [mock_metric_handler]
test_fn = table.__getattribute__(fn_name)
maybe_stream = await test_fn(*fn_args)
# iterate stream if it exists
try:
[i async for i in maybe_stream]
except TypeError:
pass
# check for recorded metrics values
assert mock_metric_handler.on_operation_complete.call_count == 1
found_operation = mock_metric_handler.on_operation_complete.call_args[0][0]
# make sure expected fields were set properly
assert found_operation.op_type == expected_type
now = datetime.datetime.now(datetime.timezone.utc)
assert found_operation.start_time - now < datetime.timedelta(seconds=1)
assert found_operation.duration < 0.1
assert found_operation.duration > 0
assert found_operation.final_status == StatusCode.OK
# metadata wasn't set, should see default values
assert found_operation.cluster_id == "unspecified"
assert found_operation.zone == "global"
# is_streaming should only be true for read_rows, read_rows_stream, and read_rows_sharded
assert found_operation.is_streaming == ("read_rows" in fn_name)
# check attempts
assert len(found_operation.completed_attempts) == 2
failure, success = found_operation.completed_attempts
for attempt in [success, failure]:
# check things that should be consistent across attempts
assert attempt.start_time - now < datetime.timedelta(seconds=1)
assert attempt.duration < 0.1
assert attempt.duration > 0
assert attempt.start_time >= found_operation.start_time
assert attempt.duration <= found_operation.duration
assert attempt.application_blocking_time == 0
assert success.end_status == StatusCode.OK
assert failure.end_status == StatusCode.ABORTED
assert success.start_time > failure.start_time + datetime.timedelta(seconds=failure.duration)
assert success.backoff_before_attempt > 0
assert failure.backoff_before_attempt == 0
table = TableAsync(client, "instance-id", "table-id")
# customize metrics handlers
mock_metric_handler = mock.Mock()
table._metrics.handlers = [mock_metric_handler]
test_fn = table.__getattribute__(fn_name)
maybe_stream = await test_fn(*fn_args, retryable_errors=(Aborted,))
# iterate stream if it exists
try:
[_ async for _ in maybe_stream]
except TypeError:
pass
# check for recorded metrics values
assert mock_metric_handler.on_operation_complete.call_count == 1
found_operation = mock_metric_handler.on_operation_complete.call_args[0][0]
# make sure expected fields were set properly
assert found_operation.op_type == expected_type
now = datetime.datetime.now(datetime.timezone.utc)
assert found_operation.start_time - now < datetime.timedelta(seconds=1)
assert found_operation.duration < 0.1
assert found_operation.duration > 0
assert found_operation.final_status == StatusCode.OK
# metadata wasn't set, should see default values
assert found_operation.cluster_id == "unspecified"
assert found_operation.zone == "global"
# is_streaming should only be true for read_rows, read_rows_stream, and read_rows_sharded
assert found_operation.is_streaming == ("read_rows" in fn_name)
# check attempts
assert len(found_operation.completed_attempts) == 2
failure, success = found_operation.completed_attempts
for attempt in [success, failure]:
# check things that should be consistent across attempts
assert attempt.start_time - now < datetime.timedelta(seconds=1)
assert attempt.duration < 0.1
assert attempt.duration > 0
assert attempt.start_time >= found_operation.start_time
assert attempt.duration <= found_operation.duration
assert attempt.application_blocking_time == 0
assert success.end_status == StatusCode.OK
assert failure.end_status == StatusCode.ABORTED
assert success.start_time > failure.start_time + datetime.timedelta(seconds=failure.duration)
assert success.backoff_before_attempt > 0
assert failure.backoff_before_attempt == 0
6 changes: 4 additions & 2 deletions tests/unit/data/test_read_rows_acceptance.py
@@ -61,6 +61,7 @@ def extract_results_from_row(row: Row):
)
@pytest.mark.asyncio
async def test_row_merger_scenario(test_case: ReadRowsTest):
from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
try:
results = []
instance = mock.Mock()
Expand All @@ -70,7 +71,9 @@ async def test_row_merger_scenario(test_case: ReadRowsTest):
stream_response=[ReadRowsResponse(chunks=test_case.chunks)]
)
chunker = _ReadRowsOperationAsync.chunk_stream(instance, stream)
merger = _ReadRowsOperationAsync.merge_rows(chunker, mock.Mock())
metric = ActiveOperationMetric(0)
metric.start_attempt()
merger = _ReadRowsOperationAsync.merge_rows(chunker, metric)
async for row in merger:
for cell in row:
cell_result = ReadRowsTest.Result(
@@ -96,7 +99,6 @@ async def test_read_rows_scenario(test_case: ReadRowsTest):
try:
client = BigtableDataClientAsync()
table = client.get_table("instance", "table")
await table._register_instance_task # to avoid warning
results = []
with mock.patch.object(
table.client._gapic_client, "read_rows", mock.AsyncMock()
