Skip to content

Commit

Permalink
moved add_mapping into metaclass
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Jul 24, 2024
1 parent 223f337 commit 3600a63
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 44 deletions.
64 changes: 64 additions & 0 deletions google/cloud/bigtable/data/_sync/cross_sync/_mapping_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any


class MappingMeta(type):
"""
Metaclass to provide add_mapping functionality, allowing users to add
custom attributes to derived classes at runtime.
Using a metaclass allows us to share functionality between CrossSync
and CrossSync._Sync_Impl, and it works better with mypy checks than
monkypatching
"""

# list of attributes that can be added to the derived class at runtime
_runtime_replacements: dict[tuple[MappingMeta, str], Any] = {}

def add_mapping(cls: MappingMeta, name: str, value: Any):
"""
Add a new attribute to the class, for replacing library-level symbols
Raises:
- AttributeError if the attribute already exists with a different value
"""
key = (cls, name)
old_value = cls._runtime_replacements.get(key)
if old_value is None:
cls._runtime_replacements[key] = value
elif old_value != value:
raise AttributeError(f"Conflicting assignments for CrossSync.{name}")

def add_mapping_decorator(cls: MappingMeta, name: str):
"""
Exposes add_mapping as a class decorator
"""

def decorator(wrapped_cls):
cls.add_mapping(name, wrapped_cls)
return wrapped_cls

return decorator

def __getattr__(cls: MappingMeta, name: str):
"""
Retrieve custom attributes
"""
key = (cls, name)
found = cls._runtime_replacements.get(key)
if found is not None:
return found
raise AttributeError(f"CrossSync has no attribute {name}")
47 changes: 3 additions & 44 deletions google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ async def async_func(self, arg: int) -> int:
```
"""


from __future__ import annotations

from typing import (
Expand Down Expand Up @@ -67,14 +66,15 @@ async def async_func(self, arg: int) -> int:
Pytest,
PytestFixture,
)
from ._mapping_meta import MappingMeta

if TYPE_CHECKING:
from typing_extensions import TypeAlias

T = TypeVar("T")


class CrossSync:
class CrossSync(metaclass=MappingMeta):
# support CrossSync.is_async to check if the current environment is async
is_async = True

Expand Down Expand Up @@ -105,23 +105,6 @@ class CrossSync:
PytestFixture.decorator
) # decorate test methods to run with pytest fixture

# list of attributes that can be added to the CrossSync class at runtime
_runtime_replacements: set[Any] = set()

@classmethod
def add_mapping(cls, name, value):
"""
Add a new attribute to the CrossSync class, for replacing library-level symbols
Raises:
- AttributeError if the attribute already exists with a different value
"""
if not hasattr(cls, name):
cls._runtime_replacements.add(name)
elif value != getattr(cls, name):
raise AttributeError(f"Conflicting assignments for CrossSync.{name}")
setattr(cls, name, value)

@classmethod
def Mock(cls, *args, **kwargs):
"""
Expand Down Expand Up @@ -256,7 +239,7 @@ def rm_aio(statement: Any) -> Any:
"""
return statement

class _Sync_Impl:
class _Sync_Impl(metaclass=MappingMeta):
"""
Provide sync versions of the async functions and types in CrossSync
"""
Expand All @@ -280,30 +263,6 @@ class _Sync_Impl:
Iterator: TypeAlias = typing.Iterator
Generator: TypeAlias = typing.Generator

_runtime_replacements: set[Any] = set()

@classmethod
def add_mapping_decorator(cls, name):
def decorator(wrapped_cls):
cls.add_mapping(name, wrapped_cls)
return wrapped_cls

return decorator

@classmethod
def add_mapping(cls, name, value):
"""
Add a new attribute to the CrossSync class, for replacing library-level symbols
Raises:
- AttributeError if the attribute already exists with a different value
"""
if not hasattr(cls, name):
cls._runtime_replacements.add(name)
elif value != getattr(cls, name):
raise AttributeError(f"Conflicting assignments for CrossSync.{name}")
setattr(cls, name, value)

@classmethod
def Mock(cls, *args, **kwargs):
# try/except added for compatibility with python < 3.8
Expand Down

0 comments on commit 3600a63

Please sign in to comment.