Skip to content

Commit

Permalink
mypy changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kboroszko committed Aug 7, 2024
1 parent 4196fff commit 712edef
Show file tree
Hide file tree
Showing 18 changed files with 144 additions and 75 deletions.
1 change: 1 addition & 0 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ async def execute_query(
return ExecuteQueryIteratorAsync(
self,
instance_id,
app_profile_id,
request_body,
attempt_timeout,
operation_timeout,
Expand Down
1 change: 0 additions & 1 deletion google/cloud/bigtable/data/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,4 +319,3 @@ class InvalidExecuteQueryResponse(core_exceptions.GoogleAPICallError):

class ParameterTypeInferenceFailed(ValueError):
"""Exception raised when query parameter types were not provided and cannot be inferred."""

Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from google.cloud.bigtable.data.exceptions import InvalidExecuteQueryResponse
from google.cloud.bigtable.data.execute_query.values import QueryResultRow
from google.cloud.bigtable.data.execute_query.metadata import Metadata
from google.cloud.bigtable.data.execute_query.metadata import Metadata, ProtoMetadata
from google.cloud.bigtable.data.execute_query._reader import (
_QueryResultRowReader,
_Reader,
Expand All @@ -46,14 +46,14 @@

class ExecuteQueryIteratorAsync:
"""
ExecuteQueryIteratorAsync handles collecting streaming responses from the
ExecuteQueryIteratorAsync handles collecting streaming responses from the
ExecuteQuery RPC and parsing them to `QueryResultRow`s.
ExecuteQueryIteratorAsync implements Asynchronous Iterator interface and can
be used with "async for" syntax. It is also a context manager.
It is **not thread-safe**. It should not be used by multiple asyncio Tasks.
Args:
client (google.cloud.bigtable.data._async.BigtableDataClientAsync): bigtable client
instance_id (str): id of the instance on which the query is executed
Expand All @@ -68,20 +68,23 @@ class ExecuteQueryIteratorAsync:
req_metadata (Sequence[Tuple[str, str]]): metadata used while sending the gRPC request
retryable_excs (List[type[Exception]]): a list of errors that will be retried if encountered.
"""

def __init__(
self,
client: Any,
instance_id: str,
app_profile_id: Optional[str],
request_body: Dict[str, Any],
attempt_timeout: float | None,
operation_timeout: float,
req_metadata: Sequence[Tuple[str, str]],
retryable_excs: List[type[Exception]],
) -> None:
self._table_name = None
self._app_profile_id = app_profile_id
self._client = client
self._instance_id = instance_id
self._byte_cursor = _ByteCursor()
self._byte_cursor = _ByteCursor[ProtoMetadata]()
self._reader: _Reader[QueryResultRow] = _QueryResultRowReader(self._byte_cursor)
self._result_generator = self._next_impl()
self._register_instance_task = None
Expand Down Expand Up @@ -112,6 +115,10 @@ def __init__(
def is_closed(self):
return self._is_closed

@property
def app_profile_id(self):
    """Returns the app profile id this query executes under, or None.

    Read-only; mirrors the ``app_profile_id`` argument passed to
    ``__init__`` (declared ``Optional[str]`` there).
    """
    return self._app_profile_id

@property
def table_name(self):
    """Returns the table name associated with this query.

    NOTE(review): ``__init__`` sets ``self._table_name = None`` and no
    visible code assigns anything else — presumably populated later from
    response metadata; confirm before relying on a non-None value.
    """
    return self._table_name
Expand Down Expand Up @@ -179,7 +186,7 @@ def __aiter__(self):

async def metadata(self) -> Optional[Metadata]:
"""
Returns query metadata from the server or None if the iterator was
Returns query metadata from the server or None if the iterator was
explicitly closed.
"""
if self._is_closed:
Expand Down
13 changes: 8 additions & 5 deletions google/cloud/bigtable/data/execute_query/_byte_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Any, Generic, Optional, TypeVar

from google.cloud.bigtable_v2 import ExecuteQueryResponse
from google.cloud.bigtable.data.execute_query.metadata import (
Metadata,
_pb_metadata_to_metadata_types,
)

MT = TypeVar("MT", bound=Metadata) # metadata type

class _ByteCursor:

class _ByteCursor(Generic[MT]):
"""
Buffers bytes from `ExecuteQuery` responses until resume_token is received or end-of-stream
is reached. :class:`google.cloud.bigtable_v2.types.bigtable.ExecuteQueryResponse` obtained from
Expand All @@ -35,13 +37,13 @@ class _ByteCursor:
"""

def __init__(self):
self._metadata: Optional[Metadata] = None
self._metadata: Optional[MT] = None
self._buffer = bytearray()
self._resume_token = None
self._last_response_results_field = None

@property
def metadata(self) -> Optional[Metadata]:
def metadata(self) -> Optional[MT]:
"""
Returns:
Metadata or None: Metadata read from the first response of the stream
Expand Down Expand Up @@ -91,7 +93,8 @@ def consume_metadata(self, response: ExecuteQueryResponse) -> None:
raise ValueError("Invalid state - metadata already consumed")

if "metadata" in response:
self._metadata = _pb_metadata_to_metadata_types(response.metadata)
metadata: Any = _pb_metadata_to_metadata_types(response.metadata)
self._metadata = metadata
else:
raise ValueError("Invalid parameter - response without metadata")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Any, Optional
from typing import Any, Dict, Optional
import datetime
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
from google.cloud.bigtable.data.exceptions import ParameterTypeInferenceFailed
Expand All @@ -23,14 +23,14 @@
def _format_execute_query_params(
params: Optional[Dict[str, ExecuteQueryValueType]],
parameter_types: Optional[Dict[str, SqlType.Type]],
) -> Dict:
) -> Any:
"""
Takes a dictionary of param_name -> param_value and optionally parameter types.
If the parameters types are not provided, this function tries to infer them.
Args:
params (Optional[Dict[str, ExecuteQueryValueType]]): mapping from parameter names
like they appear in query (without @ at the beginning) to their values.
like they appear in query (without @ at the beginning) to their values.
Only values of type ExecuteQueryValueType are permitted.
parameter_types (Optional[Dict[str, SqlType.Type]]): mapping of parameter names
to their types.
Expand Down Expand Up @@ -69,7 +69,9 @@ def _format_execute_query_params(
return result_values


def _convert_value_to_pb_value_dict(value: ExecuteQueryValueType, param_type: SqlType.Type) -> Dict:
def _convert_value_to_pb_value_dict(
value: ExecuteQueryValueType, param_type: SqlType.Type
) -> Any:
"""
Takes a value and converts it to a dictionary parsable to a protobuf.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from typing import Any, Callable, Dict, Type
from google.cloud.bigtable.data.execute_query.values import Struct
from google.cloud.bigtable.data.execute_query.metadata import SqlType
from google.cloud.bigtable_v2 import Value as PBValue
Expand All @@ -32,7 +32,7 @@
}


def _parse_array_type(value: PBValue, metadata_type: SqlType.Array) -> list:
def _parse_array_type(value: PBValue, metadata_type: SqlType.Array) -> Any:
"""
used for parsing an array represented as a protobuf to a python list.
"""
Expand All @@ -46,18 +46,17 @@ def _parse_array_type(value: PBValue, metadata_type: SqlType.Array) -> list:
)


def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> dict:
def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> Any:
"""
used for parsing a map represented as a protobuf to a python dict.
Values of type `Map` are stored in a `Value.array_value` where each entry
is another `Value.array_value` with two elements (the key and the value,
in that order).
Normally encoded Map values won't have repeated keys, however, the client
Normally encoded Map values won't have repeated keys, however, the client
must handle the case in which they do. If the same key appears
multiple times, the _last_ value takes precedence.
"""


try:
return dict(
Expand All @@ -79,7 +78,7 @@ def _parse_map_type(value: PBValue, metadata_type: SqlType.Map) -> dict:

def _parse_struct_type(value: PBValue, metadata_type: SqlType.Struct) -> Struct:
"""
used for parsing a struct represented as a protobuf to a
used for parsing a struct represented as a protobuf to a
google.cloud.bigtable.data.execute_query.Struct
"""
if len(value.array_value.values) != len(metadata_type.fields):
Expand All @@ -102,7 +101,7 @@ def _parse_timestamp_type(
return DatetimeWithNanoseconds.from_timestamp_pb(value.timestamp_value)


_TYPE_PARSERS = {
_TYPE_PARSERS: Dict[Type[SqlType.Type], Callable[[PBValue, Any], Any]] = {
SqlType.Timestamp: _parse_timestamp_type,
SqlType.Struct: _parse_struct_type,
SqlType.Array: _parse_array_type,
Expand All @@ -112,7 +111,7 @@ def _parse_timestamp_type(

def _parse_pb_value_to_python_value(value: PBValue, metadata_type: SqlType.Type) -> Any:
"""
used for converting the value represented as a protobufs to a python object.
used for converting the value represented as a protobufs to a python object.
"""
value_kind = value.WhichOneof("kind")
if not value_kind:
Expand Down
20 changes: 12 additions & 8 deletions google/cloud/bigtable/data/execute_query/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Optional,
List,
Sequence,
cast,
)
from abc import ABC, abstractmethod

Expand All @@ -32,7 +33,7 @@
from google.cloud.bigtable.helpers import batched

from google.cloud.bigtable.data.execute_query.values import QueryResultRow
from google.cloud.bigtable.data.execute_query.metadata import Metadata
from google.cloud.bigtable.data.execute_query.metadata import ProtoMetadata


T = TypeVar("T")
Expand Down Expand Up @@ -83,7 +84,7 @@ class _QueryResultRowReader(_Reader[QueryResultRow]):
:class:`google.cloud.bigtable.byte_cursor._ByteCursor` passed in the constructor.
"""

def __init__(self, byte_cursor: _ByteCursor):
def __init__(self, byte_cursor: _ByteCursor[ProtoMetadata]):
"""
Constructs new instance of ``_QueryResultRowReader``.
Expand All @@ -97,14 +98,15 @@ def __init__(self, byte_cursor: _ByteCursor):
self._byte_cursor = byte_cursor

@property
def _metadata(self) -> Optional[Metadata]:
def _metadata(self) -> Optional[ProtoMetadata]:
return self._byte_cursor.metadata

def _construct_query_result_row(
self, values: Sequence[PBValue]
) -> List[QueryResultRow]:
def _construct_query_result_row(self, values: Sequence[PBValue]) -> QueryResultRow:
result = QueryResultRow()
columns = self._metadata.columns
# The logic, not defined by mypy types, ensures that the value of
# "metadata" is never null at the time it is retrieved here
metadata = cast(ProtoMetadata, self._metadata)
columns = metadata.columns

assert len(values) == len(
columns
Expand All @@ -125,7 +127,9 @@ def consume(self, bytes_to_consume: bytes) -> Optional[Iterable[QueryResultRow]]

self._values.extend(self._parse_proto_rows(bytes_to_consume))

num_columns = len(self._metadata.columns)
# The logic, not defined by mypy types, ensures that the value of
# "metadata" is never null at the time it is retrieved here
num_columns = len(cast(ProtoMetadata, self._metadata).columns)

if len(self._values) < num_columns:
return None
Expand Down
Loading

0 comments on commit 712edef

Please sign in to comment.