Commit 7ea3c23

chore: add cross_sync annotations (#1000)

daniel-sanche authored Nov 20, 2024
1 parent 511abb1 commit 7ea3c23

Showing 32 changed files with 3,430 additions and 3,297 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/conformance.yaml
@@ -26,9 +26,9 @@ jobs:
       matrix:
         test-version: [ "v0.0.2" ]
         py-version: [ 3.8 ]
-        client-type: [ "Async v3", "Legacy" ]
+        client-type: [ "async", "legacy" ]
       fail-fast: false
-    name: "${{ matrix.client-type }} Client / Python ${{ matrix.py-version }} / Test Tag ${{ matrix.test-version }}"
+    name: "${{ matrix.client-type }} client / python ${{ matrix.py-version }} / test tag ${{ matrix.test-version }}"
     steps:
       - uses: actions/checkout@v4
         name: "Checkout python-bigtable"
3 changes: 1 addition & 2 deletions .kokoro/conformance.sh
@@ -23,15 +23,14 @@ PROXY_ARGS=""
 TEST_ARGS=""
 if [[ "${CLIENT_TYPE^^}" == "LEGACY" ]]; then
     echo "Using legacy client"
-    PROXY_ARGS="--legacy-client"
     # legacy client does not expose mutate_row. Disable those tests
     TEST_ARGS="-skip TestMutateRow_"
 fi
 
 # Build and start the proxy in a separate process
 PROXY_PORT=9999
 pushd test_proxy
-nohup python test_proxy.py --port $PROXY_PORT $PROXY_ARGS &
+nohup python test_proxy.py --port $PROXY_PORT --client_type=$CLIENT_TYPE &
 proxyPID=$!
 popd
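Note: the script now forwards CLIENT_TYPE straight to the proxy instead of translating it into a --legacy-client flag. The argparse wiring below is a hypothetical sketch of how test_proxy.py might consume the new --client_type option; the real script's argument handling is not part of this diff.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=9999)
    # one flag selects the implementation, replacing the old --legacy-client boolean
    parser.add_argument("--client_type", default="async", choices=["async", "legacy"])
    args = parser.parse_args()
    if args.client_type == "legacy":
        print("Using legacy client")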
16 changes: 15 additions & 1 deletion google/cloud/bigtable/data/__init__.py
@@ -45,16 +45,30 @@
 from google.cloud.bigtable.data._helpers import RowKeySamples
 from google.cloud.bigtable.data._helpers import ShardedQuery
 
+# setup custom CrossSync mappings for library
+from google.cloud.bigtable_v2.services.bigtable.async_client import (
+    BigtableAsyncClient,
+)
+from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync
+from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync
+
+from google.cloud.bigtable.data._cross_sync import CrossSync
+
+CrossSync.add_mapping("GapicClient", BigtableAsyncClient)
+CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync)
+CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync)
+CrossSync.add_mapping("MutationsBatcher", MutationsBatcherAsync)
+
 
 __version__: str = package_version.__version__
 
 __all__ = (
     "BigtableDataClientAsync",
     "TableAsync",
+    "MutationsBatcherAsync",
     "RowKeySamples",
     "ReadRowsQuery",
     "RowRange",
-    "MutationsBatcherAsync",
     "Mutation",
     "RowMutationEntry",
     "SetCell",
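CrossSync.add_mapping registers a concrete implementation under a neutral name, so annotated code can refer to CrossSync.GapicClient while the generated sync build binds the same name to the sync class. A minimal sketch of the idea, assuming the mapping is just a class attribute (the real _cross_sync module is not shown in this commit):

    class CrossSync:
        is_async = True

        @classmethod
        def add_mapping(cls, name: str, value) -> None:
            # After this call, annotated code can reference CrossSync.<name>;
            # the sync flavor registers its own classes under the same names.
            setattr(cls, name, value)


    class FakeGapicClient:  # stand-in for BigtableAsyncClient
        pass


    CrossSync.add_mapping("GapicClient", FakeGapicClient)
    assert CrossSync.GapicClient is FakeGapicClient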
40 changes: 22 additions & 18 deletions google/cloud/bigtable/data/_async/_mutate_rows.py
@@ -15,37 +15,38 @@
 from __future__ import annotations
 
 from typing import Sequence, TYPE_CHECKING
-from dataclasses import dataclass
 import functools
 
 from google.api_core import exceptions as core_exceptions
 from google.api_core import retry as retries
-import google.cloud.bigtable_v2.types.bigtable as types_pb
 import google.cloud.bigtable.data.exceptions as bt_exceptions
 from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
 from google.cloud.bigtable.data._helpers import _retry_exception_factory
 
 # mutate_rows requests are limited to this number of mutations
 from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
+from google.cloud.bigtable.data.mutations import _EntryWithProto
+
+from google.cloud.bigtable.data._cross_sync import CrossSync
 
 if TYPE_CHECKING:
-    from google.cloud.bigtable_v2.services.bigtable.async_client import (
-        BigtableAsyncClient,
-    )
     from google.cloud.bigtable.data.mutations import RowMutationEntry
-    from google.cloud.bigtable.data._async.client import TableAsync
-
-
-@dataclass
-class _EntryWithProto:
-    """
-    A dataclass to hold a RowMutationEntry and its corresponding proto representation.
-    """
+    if CrossSync.is_async:
+        from google.cloud.bigtable_v2.services.bigtable.async_client import (
+            BigtableAsyncClient as GapicClientType,
+        )
+        from google.cloud.bigtable.data._async.client import TableAsync as TableType
+    else:
+        from google.cloud.bigtable_v2.services.bigtable.client import (  # type: ignore
+            BigtableClient as GapicClientType,
+        )
+        from google.cloud.bigtable.data._sync_autogen.client import Table as TableType  # type: ignore
 
-    entry: RowMutationEntry
-    proto: types_pb.MutateRowsRequest.Entry
+__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._mutate_rows"
 
 
+@CrossSync.convert_class("_MutateRowsOperation")
 class _MutateRowsOperationAsync:
     """
     MutateRowsOperation manages the logic of sending a set of row mutations,
@@ -65,10 +66,11 @@ class _MutateRowsOperationAsync:
         If not specified, the request will run until operation_timeout is reached.
     """
 
+    @CrossSync.convert
     def __init__(
         self,
-        gapic_client: "BigtableAsyncClient",
-        table: "TableAsync",
+        gapic_client: GapicClientType,
+        table: TableType,
         mutation_entries: list["RowMutationEntry"],
         operation_timeout: float,
         attempt_timeout: float | None,
@@ -97,7 +99,7 @@ def __init__(
             bt_exceptions._MutateRowsIncomplete,
         )
         sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60)
-        self._operation = retries.retry_target_async(
+        self._operation = lambda: CrossSync.retry_target(
             self._run_attempt,
             self.is_retryable,
             sleep_generator,
@@ -112,6 +114,7 @@ def __init__(
         self.remaining_indices = list(range(len(self.mutations)))
         self.errors: dict[int, list[Exception]] = {}
 
+    @CrossSync.convert
     async def start(self):
         """
         Start the operation, and run until completion
@@ -121,7 +124,7 @@ async def start(self):
         """
         try:
             # trigger mutate_rows
-            await self._operation
+            await self._operation()
         except Exception as exc:
             # exceptions raised by retryable are added to the list of exceptions for all unfinalized mutations
             incomplete_indices = self.remaining_indices.copy()
@@ -148,6 +151,7 @@ async def start(self):
                 all_errors, len(self.mutations)
             )
 
+    @CrossSync.convert
    async def _run_attempt(self):
        """
        Run a single attempt of the mutate_rows rpc.
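The @CrossSync.convert_class and @CrossSync.convert annotations carry instructions for a code generator that emits the module named by __CROSS_SYNC_OUTPUT__; at import time they can behave as near no-ops. A rough sketch of that contract, under the assumption that the decorators simply record a target name and return the class or function unchanged:

    class CrossSync:
        @staticmethod
        def convert_class(sync_name=None, **kwargs):
            def decorator(cls):
                cls.cross_sync_name = sync_name  # consumed by the generator
                return cls
            return decorator

        @staticmethod
        def convert(func=None, **kwargs):
            if func is None:  # used as @CrossSync.convert(...) with arguments
                return lambda f: f
            return func  # used as bare @CrossSync.convert


    @CrossSync.convert_class("_MutateRowsOperation")
    class _MutateRowsOperationAsync:
        @CrossSync.convert
        async def start(self):
            await self._operation()

    # The generator would then emit roughly this into
    # google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py:
    #
    #   class _MutateRowsOperation:
    #       def start(self):
    #           self._operation()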
46 changes: 24 additions & 22 deletions google/cloud/bigtable/data/_async/_read_rows.py
@@ -15,13 +15,7 @@
 
 from __future__ import annotations
 
-from typing import (
-    TYPE_CHECKING,
-    AsyncGenerator,
-    AsyncIterable,
-    Awaitable,
-    Sequence,
-)
+from typing import Sequence, TYPE_CHECKING
 
 from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB
 from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB
@@ -32,21 +26,25 @@
 from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
 from google.cloud.bigtable.data.exceptions import InvalidChunk
 from google.cloud.bigtable.data.exceptions import _RowSetComplete
+from google.cloud.bigtable.data.exceptions import _ResetRow
 from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
 from google.cloud.bigtable.data._helpers import _retry_exception_factory
 
-from google.api_core import retry as retries
 from google.api_core.retry import exponential_sleep_generator
 
-if TYPE_CHECKING:
-    from google.cloud.bigtable.data._async.client import TableAsync
+from google.cloud.bigtable.data._cross_sync import CrossSync
 
+if TYPE_CHECKING:
+    if CrossSync.is_async:
+        from google.cloud.bigtable.data._async.client import TableAsync as TableType
+    else:
+        from google.cloud.bigtable.data._sync_autogen.client import Table as TableType  # type: ignore
 
-class _ResetRow(Exception):
-    def __init__(self, chunk):
-        self.chunk = chunk
+__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._read_rows"
 
 
+@CrossSync.convert_class("_ReadRowsOperation")
 class _ReadRowsOperationAsync:
     """
     ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream
@@ -80,7 +78,7 @@ class _ReadRowsOperationAsync:
     def __init__(
         self,
         query: ReadRowsQuery,
-        table: "TableAsync",
+        table: TableType,
         operation_timeout: float,
         attempt_timeout: float,
         retryable_exceptions: Sequence[type[Exception]] = (),
@@ -102,22 +100,22 @@ def __init__(
         self._last_yielded_row_key: bytes | None = None
         self._remaining_count: int | None = self.request.rows_limit or None
 
-    def start_operation(self) -> AsyncGenerator[Row, None]:
+    def start_operation(self) -> CrossSync.Iterable[Row]:
         """
         Start the read_rows operation, retrying on retryable errors.
 
         Yields:
             Row: The next row in the stream
         """
-        return retries.retry_target_stream_async(
+        return CrossSync.retry_target_stream(
             self._read_rows_attempt,
             self._predicate,
             exponential_sleep_generator(0.01, 60, multiplier=2),
             self.operation_timeout,
             exception_factory=_retry_exception_factory,
         )
 
-    def _read_rows_attempt(self) -> AsyncGenerator[Row, None]:
+    def _read_rows_attempt(self) -> CrossSync.Iterable[Row]:
         """
         Attempt a single read_rows rpc call.
         This function is intended to be wrapped by retry logic,
@@ -152,9 +150,10 @@ def _read_rows_attempt(self) -> AsyncGenerator[Row, None]:
         chunked_stream = self.chunk_stream(gapic_stream)
         return self.merge_rows(chunked_stream)
 
+    @CrossSync.convert()
     async def chunk_stream(
-        self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]]
-    ) -> AsyncGenerator[ReadRowsResponsePB.CellChunk, None]:
+        self, stream: CrossSync.Awaitable[CrossSync.Iterable[ReadRowsResponsePB]]
+    ) -> CrossSync.Iterable[ReadRowsResponsePB.CellChunk]:
         """
         process chunks out of raw read_rows stream
@@ -204,9 +203,12 @@ async def chunk_stream(
                     current_key = None
 
     @staticmethod
+    @CrossSync.convert(
+        replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"},
+    )
     async def merge_rows(
-        chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None
-    ) -> AsyncGenerator[Row, None]:
+        chunks: CrossSync.Iterable[ReadRowsResponsePB.CellChunk] | None,
+    ) -> CrossSync.Iterable[Row]:
         """
         Merge chunks into rows
@@ -222,7 +224,7 @@ async def merge_rows(
         while True:
             try:
                 c = await it.__anext__()
-            except StopAsyncIteration:
+            except CrossSync.StopIteration:
                 # stream complete
                 return
             row_key = c.row_key
@@ -315,7 +317,7 @@ async def merge_rows(
                 ):
                     raise InvalidChunk("reset row with data")
                 continue
-        except StopAsyncIteration:
+        except CrossSync.StopIteration:
             raise InvalidChunk("premature end of stream")
 
     @staticmethod
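At runtime the annotated module leans on CrossSync aliases (retry_target, retry_target_stream, StopIteration, Iterable, Awaitable) that resolve to asyncio primitives in the async build and to plain equivalents in the generated sync build. A plausible sketch of those aliases, assuming both sides wrap google.api_core.retry (only the names used in this diff are shown; the real _cross_sync module may differ):

    import typing
    from google.api_core import retry as retries


    class CrossSync:
        """Async flavor: what the _async modules would see."""
        is_async = True
        retry_target = retries.retry_target_async                # used by _mutate_rows
        retry_target_stream = retries.retry_target_stream_async  # used by _read_rows
        StopIteration = StopAsyncIteration
        Iterable = typing.AsyncIterable
        Awaitable = typing.Awaitable


    class CrossSyncSync:
        """Sync flavor: what the generated _sync_autogen modules would bind."""
        is_async = False
        retry_target = retries.retry_target
        retry_target_stream = retries.retry_target_stream
        StopIteration = StopIteration  # the builtin
        Iterable = typing.Iterable
        Awaitable = typing.Iterable  # sync code consumes values directly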
(Diffs for the remaining 27 changed files are not shown.)