Commit: use add_mapping in place of replace_symbols

daniel-sanche committed Jul 12, 2024
1 parent fd1fb71 · commit adb092e
Showing 6 changed files with 45 additions and 49 deletions.
13 changes: 5 additions & 8 deletions google/cloud/bigtable/data/_async/_mutate_rows.py
@@ -38,6 +38,8 @@
from google.cloud.bigtable_v2.services.bigtable.async_client import (
BigtableAsyncClient,
)
CrossSync.add_mapping("Table", TableAsync)
CrossSync.add_mapping("GapicClient", BigtableAsyncClient)


@CrossSync.export_sync(
@@ -62,16 +64,11 @@ class _MutateRowsOperationAsync:
If not specified, the request will run until operation_timeout is reached.
"""

@CrossSync.convert(
replace_symbols={
"BigtableAsyncClient": "BigtableClient",
"TableAsync": "Table",
}
)
@CrossSync.convert
def __init__(
self,
gapic_client: "BigtableAsyncClient",
table: "TableAsync",
gapic_client: "CrossSync.GapicClient",
table: "CrossSync.Table",
mutation_entries: list["RowMutationEntry"],
operation_timeout: float,
attempt_timeout: float | None,
4 changes: 2 additions & 2 deletions google/cloud/bigtable/data/_async/_read_rows.py
@@ -43,6 +43,7 @@
if TYPE_CHECKING:
if CrossSync.is_async:
from google.cloud.bigtable.data._async.client import TableAsync
CrossSync.add_mapping("Table", TableAsync)


@CrossSync.export_sync(
@@ -79,11 +80,10 @@ class _ReadRowsOperationAsync:
"_remaining_count",
)

@CrossSync.convert(replace_symbols={"TableAsync": "Table"})
def __init__(
self,
query: ReadRowsQuery,
table: "TableAsync",
table: "CrossSync.Table",
operation_timeout: float,
attempt_timeout: float,
retryable_exceptions: Sequence[type[Exception]] = (),
38 changes: 16 additions & 22 deletions google/cloud/bigtable/data/_async/client.py
@@ -91,6 +91,13 @@
from google.cloud.bigtable_v2.services.bigtable.async_client import (
BigtableAsyncClient,
)
# define file-specific cross-sync replacements
CrossSync.add_mapping("GapicClient", BigtableAsyncClient)
CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcAsyncIOTransport)
CrossSync.add_mapping("PooledChannel", AsyncPooledChannel)
CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync)
CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync)


if TYPE_CHECKING:
from google.cloud.bigtable.data._helpers import RowKeySamples
@@ -101,13 +108,7 @@
path="google.cloud.bigtable.data._sync.client.BigtableDataClient",
)
class BigtableDataClientAsync(ClientWithProject):
@CrossSync.convert(
replace_symbols={
"BigtableAsyncClient": "BigtableClient",
"PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport",
"AsyncPooledChannel": "PooledChannel",
}
)
@CrossSync.convert
def __init__(
self,
*,
@@ -143,7 +144,7 @@ def __init__(
"""
# set up transport in registry
transport_str = f"bt-{self._client_version()}-{pool_size}"
transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size)
transport = CrossSync.PooledTransport.with_fixed_size(pool_size)
BigtableClientMeta._transport_registry[transport_str] = transport
# set up client info headers for veneer library
client_info = DEFAULT_CLIENT_INFO
@@ -168,15 +169,15 @@ def __init__(
project=project,
client_options=client_options,
)
self._gapic_client = BigtableAsyncClient(
self._gapic_client = CrossSync.GapicClient(
transport=transport_str,
credentials=credentials,
client_options=client_options,
client_info=client_info,
)
self._is_closed = CrossSync.Event()
self.transport = cast(
PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport
CrossSync.PooledTransport, self._gapic_client.transport
)
# keep track of active instances for warmup on channel refresh
self._active_instances: Set[_WarmedInstanceKey] = set()
@@ -195,7 +196,7 @@ def __init__(
RuntimeWarning,
stacklevel=2,
)
self.transport._grpc_channel = AsyncPooledChannel(
self.transport._grpc_channel = CrossSync.PooledChannel(
pool_size=pool_size,
host=self._emulator_host,
insecure=True,
@@ -611,12 +612,7 @@ def __init__(
f"{self.__class__.__name__} must be created within an async event loop context."
) from e

@CrossSync.convert(
replace_symbols={
"AsyncIterable": "Iterable",
"_ReadRowsOperationAsync": "_ReadRowsOperation",
}
)
@CrossSync.convert(replace_symbols={"AsyncIterable": "Iterable"})
async def read_rows_stream(
self,
query: ReadRowsQuery,
@@ -658,7 +654,7 @@ async def read_rows_stream(
)
retryable_excs = _get_retryable_errors(retryable_errors, self)

row_merger = _ReadRowsOperationAsync(
row_merger = CrossSync._ReadRowsOperation(
query,
self,
operation_timeout=operation_timeout,
@@ -1116,9 +1112,7 @@ async def mutate_row(
exception_factory=_retry_exception_factory,
)

@CrossSync.convert(
replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"}
)
@CrossSync.convert
async def bulk_mutate_rows(
self,
mutation_entries: list[RowMutationEntry],
@@ -1164,7 +1158,7 @@ async def bulk_mutate_rows(
)
retryable_excs = _get_retryable_errors(retryable_errors, self)

operation = _MutateRowsOperationAsync(
operation = CrossSync._MutateRowsOperation(
self.client._gapic_client,
self,
mutation_entries,
8 changes: 3 additions & 5 deletions google/cloud/bigtable/data/_async/mutations_batcher.py
@@ -36,7 +36,7 @@

if CrossSync.is_async:
from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync

CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync)

if TYPE_CHECKING:
from google.cloud.bigtable.data.mutations import RowMutationEntry
@@ -361,9 +361,7 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]):
self._entries_processed_since_last_raise += len(new_entries)
self._add_exceptions(found_exceptions)

@CrossSync.convert(
replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"}
)
@CrossSync.convert
async def _execute_mutate_rows(
self, batch: list[RowMutationEntry]
) -> list[FailedMutationEntryError]:
@@ -380,7 +378,7 @@ async def _execute_mutate_rows(
FailedMutationEntryError objects will not contain index information
"""
try:
operation = _MutateRowsOperationAsync(
operation = CrossSync._MutateRowsOperation(
self._table.client._gapic_client,
self._table,
batch,
17 changes: 17 additions & 0 deletions google/cloud/bigtable/data/_sync/cross_sync.py
@@ -219,6 +219,21 @@ class CrossSync(metaclass=_DecoratorMeta):
Iterator: TypeAlias = AsyncIterator
Generator: TypeAlias = AsyncGenerator

@classmethod
def add_mapping(cls, name, value):
"""
Add a new attribute to the CrossSync class, for replacing library-level symbols
Raises:
- AttributeError if the attribute already exists with a different value
"""
if not hasattr(cls, name):
cls._runtime_replacements.add(name)
elif value != getattr(cls, name):
raise AttributeError(f"Conflicting assignments for CrossSync.{name}")
setattr(cls, name, value)

# list of decorators that can be applied to classes and methods to guide code generation
_decorators: list[AstDecorator] = [
AstDecorator("export_sync", # decorate classes to convert
required_keywords=["path"], # output path for generated sync class
@@ -244,6 +259,8 @@ class CrossSync(metaclass=_DecoratorMeta):
name=None,
),
]
# list of attributes that can be added to the CrossSync class at runtime
_runtime_replacements: set[Any] = set()

@classmethod
def Mock(cls, *args, **kwargs):
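For context, a minimal sketch (not part of this commit) of how the new CrossSync.add_mapping hook behaves, based on the implementation above: a module registers a concrete class under a variant-neutral name, other modules, annotations, and tests then refer to it as an attribute of CrossSync, re-registering the same value is a no-op, and a conflicting value raises AttributeError. SomeAsyncImpl and SomeSyncImpl are hypothetical placeholders, and the import path is taken from the file location shown in this diff.

    # Illustrative only; SomeAsyncImpl / SomeSyncImpl are hypothetical stand-ins
    # for classes such as BigtableAsyncClient or _MutateRowsOperationAsync.
    from google.cloud.bigtable.data._sync.cross_sync import CrossSync

    class SomeAsyncImpl:
        pass

    class SomeSyncImpl:
        pass

    # register the async implementation under a shared, variant-neutral name
    CrossSync.add_mapping("SomeImpl", SomeAsyncImpl)

    # downstream code resolves the mapped symbol at runtime ...
    instance = CrossSync.SomeImpl()

    # ... or refers to it in string annotations, as the diffs above do
    def make_one(impl: "CrossSync.SomeImpl") -> None:
        ...

    # re-registering the same value is a no-op; a different value is rejected
    CrossSync.add_mapping("SomeImpl", SomeAsyncImpl)
    try:
        CrossSync.add_mapping("SomeImpl", SomeSyncImpl)
    except AttributeError as exc:
        print(exc)  # Conflicting assignments for CrossSync.SomeImpl

Presumably the generated sync modules register their own mappings under the same names, so annotations like "CrossSync.Table" resolve to the correct class in both variants without per-method replace_symbols arguments.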
14 changes: 2 additions & 12 deletions tests/unit/data/_async/test_mutations_batcher.py
@@ -938,11 +938,7 @@ async def test_timer_flush_end_to_end(self):

@CrossSync.pytest
async def test__execute_mutate_rows(self):
if CrossSync.is_async:
mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync"
else:
mutate_path = "_sync.mutations_batcher._MutateRowsOperation"
with mock.patch(f"google.cloud.bigtable.data.{mutate_path}") as mutate_rows:
with mock.patch.object(CrossSync, "_MutateRowsOperation") as mutate_rows:
mutate_rows.return_value = CrossSync.Mock()
start_operation = mutate_rows().start
table = mock.Mock()
@@ -1105,13 +1101,7 @@ async def test_timeout_args_passed(self):
batch_operation_timeout and batch_attempt_timeout should be used
in api calls
"""
if CrossSync.is_async:
mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync"
else:
mutate_path = "_sync.mutations_batcher._MutateRowsOperation"
with mock.patch(
f"google.cloud.bigtable.data.{mutate_path}", return_value=CrossSync.Mock()
) as mutate_rows:
with mock.patch.object(CrossSync, "_MutateRowsOperation", return_value=CrossSync.Mock()) as mutate_rows:
expected_operation_timeout = 17
expected_attempt_timeout = 13
async with self._make_one(
