
Commit

pr fixes 1
kboroszko committed Aug 6, 2024
1 parent 88e0492 commit 4196fff
Showing 7 changed files with 104 additions and 25 deletions.
3 changes: 1 addition & 2 deletions google/cloud/bigtable/data/_async/client.py
@@ -480,7 +480,7 @@ async def execute_query(
Defaults to 20 seconds.
If None, defaults to operation_timeout.
- retryable_errors: a list of errors that will be retried if encountered.
Defaults to the Table's default_read_rows_retryable_errors
Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted)
Returns:
- an asynchronous iterator that yields rows returned by the query
Raises:
@@ -520,7 +520,6 @@ async def execute_query(
return ExecuteQueryIteratorAsync(
self,
instance_id,
app_profile_id,
request_body,
attempt_timeout,
operation_timeout,
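A minimal usage sketch of the execute_query API touched by this hunk (project, instance, and table names are hypothetical; note the iterator no longer receives app_profile_id):

import asyncio
from google.cloud.bigtable.data import BigtableDataClientAsync

async def main():
    async with BigtableDataClientAsync(project="my-project") as client:
        # execute_query returns an ExecuteQueryIteratorAsync
        iterator = await client.execute_query(
            "SELECT * FROM my_table", instance_id="my-instance"
        )
        async for row in iterator:
            print(row)

asyncio.run(main())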
3 changes: 2 additions & 1 deletion google/cloud/bigtable/data/exceptions.py
@@ -318,4 +318,5 @@ class InvalidExecuteQueryResponse(core_exceptions.GoogleAPICallError):


class ParameterTypeInferenceFailed(ValueError):
pass
"""Exception raised when query parameter types were not provided and cannot be inferred."""

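A short sketch of when this exception fires, using the formatting helper changed later in this commit (None has no inferable SQL type):

from google.cloud.bigtable.data.execute_query._parameters_formatting import (
    _format_execute_query_params,
)

try:
    # No parameter_types given, and None carries no type information.
    _format_execute_query_params({"user_id": None}, None)
except ValueError as exc:  # ParameterTypeInferenceFailed subclasses ValueError
    print(exc)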
@@ -45,19 +45,40 @@


class ExecuteQueryIteratorAsync:
"""
ExecuteQueryIteratorAsync handles collecting streaming responses from the
ExecuteQuery RPC and parsing them to `QueryResultRow`s.
ExecuteQueryIteratorAsync implements the Asynchronous Iterator interface and can
be used with "async for" syntax. It is also a context manager.
It is **not thread-safe**. It should not be used by multiple asyncio Tasks.
Args:
client (google.cloud.bigtable.data._async.BigtableDataClientAsync): bigtable client
instance_id (str): id of the instance on which the query is executed
request_body (Dict[str, Any]): dict representing the body of the ExecuteQueryRequest
attempt_timeout (float | None): the time budget for an individual network request, in seconds.
If it takes longer than this time to complete, the request will be cancelled with
a DeadlineExceeded exception, and a retry will be attempted.
Defaults to 20 seconds. If None, defaults to operation_timeout.
operation_timeout (float): the time budget for the entire operation, in seconds.
Failed requests will be retried within the budget.
Defaults to 600 seconds.
req_metadata (Sequence[Tuple[str, str]]): metadata used while sending the gRPC request
retryable_excs (List[type[Exception]]): a list of errors that will be retried if encountered.
"""
def __init__(
self,
client: Any,
instance_id: str,
app_profile_id: Optional[str],
request_body: Dict[str, Any],
attempt_timeout: float | None,
operation_timeout: float,
req_metadata: Sequence[Tuple[str, str]],
retryable_excs: List[type[Exception]],
) -> None:
self._table_name = None
self._app_profile_id = app_profile_id
self._client = client
self._instance_id = instance_id
self._byte_cursor = _ByteCursor()
@@ -91,15 +112,14 @@ def __init__(
def is_closed(self):
return self._is_closed

@property
def app_profile_id(self):
return self._app_profile_id

@property
def table_name(self):
return self._table_name

async def _make_request_with_resume_token(self):
"""
Performs the RPC call using the correct resume token.
"""
resume_token = self._byte_cursor.prepare_for_new_request()
request = ExecuteQueryRequestPB(
{
@@ -115,11 +135,19 @@ async def _make_request_with_resume_token(self):
)

async def _await_metadata(self) -> None:
"""
If called before the first response was received, the first response
is awaited as part of this call.
"""
if self._byte_cursor.metadata is None:
metadata_msg = await self._async_stream.__anext__()
self._byte_cursor.consume_metadata(metadata_msg)

async def _next_impl(self) -> AsyncIterator[QueryResultRow]:
"""
Generator wrapping the response stream which parses the stream results
and returns full `QueryResultRow`s.
"""
await self._await_metadata()

async for response in self._async_stream:
@@ -150,6 +178,10 @@ def __aiter__(self):
return self

async def metadata(self) -> Optional[Metadata]:
"""
Returns query metadata from the server or None if the iterator was
explicitly closed.
"""
if self._is_closed:
return None
# Metadata should be present in the first response in a stream.
@@ -161,6 +193,9 @@ async def metadata(self) -> Optional[Metadata]:
return self._byte_cursor.metadata

async def close(self) -> None:
"""
Cancel all background tasks. Should be called after all rows were processed.
"""
if self._is_closed:
return
self._is_closed = True
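A sketch of how the documented methods fit together (assumes an iterator obtained from execute_query as in the earlier example; the table name is hypothetical):

async def read_all(client, instance_id):
    iterator = await client.execute_query("SELECT * FROM my_table", instance_id)
    # metadata() awaits the first response if it has not arrived yet,
    # and returns None once the iterator is closed.
    metadata = await iterator.metadata()
    rows = [row async for row in iterator]
    # Cancel background tasks after all rows were processed.
    await iterator.close()
    return metadata, rows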
39 changes: 36 additions & 3 deletions google/cloud/bigtable/data/execute_query/_parameters_formatting.py
@@ -23,7 +23,26 @@
def _format_execute_query_params(
params: Optional[Dict[str, ExecuteQueryValueType]],
parameter_types: Optional[Dict[str, SqlType.Type]],
):
) -> Dict:
"""
Takes a dictionary of param_name -> param_value and optionally parameter types.
If the parameter types are not provided, this function tries to infer them.
Args:
params (Optional[Dict[str, ExecuteQueryValueType]]): mapping from parameter names
as they appear in the query (without the @ at the beginning) to their values.
Only values of type ExecuteQueryValueType are permitted.
parameter_types (Optional[Dict[str, SqlType.Type]]): mapping of parameter names
to their types.
Raises:
ValueError: raised when parameter types cannot be inferred and were not
provided explicitly.
Returns:
dictionary parsable to a protobuf representing parameters as defined
in ExecuteQueryRequest.params
"""
if not params:
return {}
parameter_types = parameter_types or {}
@@ -50,7 +69,17 @@ def _format_execute_query_params(
return result_values


def _convert_value_to_pb_value_dict(value: Any, param_type: SqlType.Type):
def _convert_value_to_pb_value_dict(value: ExecuteQueryValueType, param_type: SqlType.Type) -> Dict:
"""
Takes a value and converts it to a dictionary parsable to a protobuf.
Args:
value (ExecuteQueryValueType): the value to convert.
param_type (SqlType.Type): object describing which ExecuteQuery type the value represents.
Returns:
dictionary parsable to a protobuf.
"""
# type field will be set only in top-level Value.
value_dict = param_type._to_value_pb_dict(value)
value_dict["type"] = param_type._to_type_pb_dict()
@@ -68,7 +97,11 @@ def _convert_value_to_pb_value_dict(value: Any, param_type: SqlType.Type):
]


def _detect_type(value):
def _detect_type(value: ExecuteQueryValueType) -> SqlType.Type:
"""
Infers the ExecuteQuery type based on the value. Raises
ParameterTypeInferenceFailed if the type is ambiguous or cannot be inferred.
"""
if value is None:
raise ParameterTypeInferenceFailed(
"Cannot infer type of None, please provide the type manually."
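A sketch of explicit parameter types alongside inference (assuming SqlType is importable from google.cloud.bigtable.data.execute_query.metadata):

from google.cloud.bigtable.data.execute_query.metadata import SqlType
from google.cloud.bigtable.data.execute_query._parameters_formatting import (
    _format_execute_query_params,
)

# Non-None values are inferred by _detect_type.
pb_params = _format_execute_query_params({"name": "john", "age": 42}, None)

# None is ambiguous, so its type must be provided explicitly.
pb_params = _format_execute_query_params(
    {"name": None},
    {"name": SqlType.String()},
)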
@@ -33,6 +33,9 @@


def _parse_array_type(value: PBValue, metadata_type: SqlType.Array) -> list:
"""
Used for parsing an array represented as a protobuf to a Python list.
"""
return list(
map(
lambda val: _parse_pb_value_to_python_value(
@@ -44,12 +47,17 @@ def _parse_array_type(value: PBValue, metadata_type: SqlType.Array) -> list:


def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> dict:
# Values of type `Map` are stored in a `Value.array_value` where each entry
# is another `Value.array_value` with two elements (the key and the value,
# in that order).
# Normally encoded Map values won't have repeated keys, however, clients are
# expected to handle the case in which they do. If the same key appears
# multiple times, the _last_ value takes precedence.
"""
Used for parsing a map represented as a protobuf to a Python dict.
Values of type `Map` are stored in a `Value.array_value` where each entry
is another `Value.array_value` with two elements (the key and the value,
in that order).
Normally encoded Map values won't have repeated keys, however, the client
must handle the case in which they do. If the same key appears
multiple times, the _last_ value takes precedence.
"""


try:
return dict(
@@ -70,6 +78,10 @@ def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> dict:


def _parse_struct_type(value: PBValue, metadata_type: SqlType.Struct) -> Struct:
"""
Used for parsing a struct represented as a protobuf to a
google.cloud.bigtable.data.execute_query.Struct
"""
if len(value.array_value.values) != len(metadata_type.fields):
raise ValueError("Mismatched lengths of values and types.")

@@ -84,6 +96,9 @@ def _parse_struct_type(value: PBValue, metadata_type: SqlType.Struct) -> Struct:
def _parse_timestamp_type(
value: PBValue, metadata_type: SqlType.Timestamp
) -> DatetimeWithNanoseconds:
"""
Used for parsing a timestamp represented as a protobuf to DatetimeWithNanoseconds.
"""
return DatetimeWithNanoseconds.from_timestamp_pb(value.timestamp_value)


@@ -96,6 +111,9 @@ def _parse_timestamp_type(


def _parse_pb_value_to_python_value(value: PBValue, metadata_type: SqlType.Type) -> Any:
"""
Used for converting a value represented as a protobuf to a Python object.
"""
value_kind = value.WhichOneof("kind")
if not value_kind:
return None
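The repeated-key rule in _parse_map_type falls out of how dict() consumes key/value pairs; a plain-Python illustration:

# dict() keeps the last value for a duplicated key, matching the
# "_last_ value takes precedence" behavior described above.
pairs = [("k", 1), ("k", 2)]
assert dict(pairs) == {"k": 2}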
2 changes: 0 additions & 2 deletions google/cloud/bigtable_v2/__init__.py
@@ -122,8 +122,6 @@
"RowSet",
"SampleRowKeysRequest",
"SampleRowKeysResponse",
"ExecuteQueryRequest",
"ExecuteQueryResponse",
"StreamContinuationToken",
"StreamContinuationTokens",
"StreamPartition",
5 changes: 0 additions & 5 deletions tests/unit/data/test_execute_query_parameters_parsing.py
@@ -79,11 +79,6 @@ def test_instance_execute_query_parameters_simple_types_parsing(
},
None,
)
print("RESULT")
print(type(result["test"][value_field]))
print(result["test"][value_field])
print(type(expected_value))
print(expected_value)
assert result["test"][value_field] == expected_value
assert type_field in result["test"]["type"]

