Skip to content

Commit

Permalink
fixed lint issues
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Sep 6, 2024
1 parent 7a0c638 commit a3cb9a6
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 90 deletions.
30 changes: 23 additions & 7 deletions google/cloud/bigtable/data/_sync/cross_sync/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,12 @@ def __init__(
self.path = path
self.replace_symbols = replace_symbols
docstring_format_vars = docstring_format_vars or {}
self.async_docstring_format_vars = {k: v[0] for k, v in docstring_format_vars.items()}
self.sync_docstring_format_vars = {k: v[1] for k, v in docstring_format_vars.items()}
self.async_docstring_format_vars = {
k: v[0] for k, v in docstring_format_vars.items()
}
self.sync_docstring_format_vars = {
k: v[1] for k, v in docstring_format_vars.items()
}
self.mypy_ignore = mypy_ignore
self.include_file_imports = include_file_imports
self.add_mapping_for_name = add_mapping_for_name
Expand Down Expand Up @@ -271,7 +275,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals):
# update docstring if specified
if self.sync_docstring_format_vars:
docstring = ast.get_docstring(wrapped_node)
wrapped_node.body[0].value.s = docstring.format(**self.sync_docstring_format_vars)
wrapped_node.body[0].value.s = docstring.format(
**self.sync_docstring_format_vars
)
return wrapped_node


Expand Down Expand Up @@ -299,8 +305,12 @@ def __init__(
self.sync_name = sync_name
self.replace_symbols = replace_symbols
docstring_format_vars = docstring_format_vars or {}
self.async_docstring_format_vars = {k: v[0] for k, v in docstring_format_vars.items()}
self.sync_docstring_format_vars = {k: v[1] for k, v in docstring_format_vars.items()}
self.async_docstring_format_vars = {
k: v[0] for k, v in docstring_format_vars.items()
}
self.sync_docstring_format_vars = {
k: v[1] for k, v in docstring_format_vars.items()
}
self.rm_aio = rm_aio

def sync_ast_transform(self, wrapped_node, transformers_globals):
Expand All @@ -315,7 +325,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals):
wrapped_node.name,
wrapped_node.args,
wrapped_node.body,
wrapped_node.decorator_list if hasattr(wrapped_node, "decorator_list") else [],
wrapped_node.decorator_list
if hasattr(wrapped_node, "decorator_list")
else [],
wrapped_node.returns if hasattr(wrapped_node, "returns") else None,
),
wrapped_node,
Expand All @@ -333,7 +345,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals):
# update docstring if specified
if self.sync_docstring_format_vars:
docstring = ast.get_docstring(wrapped_node)
wrapped_node.body[0].value.s = docstring.format(**self.sync_docstring_format_vars)
wrapped_node.body[0].value.s = docstring.format(
**self.sync_docstring_format_vars
)
return wrapped_node

def async_decorator(self):
Expand All @@ -342,9 +356,11 @@ def async_decorator(self):
"""

if self.async_docstring_format_vars:

def decorator(f):
f.__doc__ = f.__doc__.format(**self.async_docstring_format_vars)
return f

return decorator
else:
return None
Expand Down
14 changes: 12 additions & 2 deletions tests/system/cross_sync/test_cross_sync_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,19 @@
import black
import pytest
import yaml

# add cross_sync to path
test_dir_name = os.path.dirname(__file__)
cross_sync_path = os.path.join(test_dir_name, "..", "..", "..", ".cross_sync")
sys.path.append(cross_sync_path)

from transformers import SymbolReplacer, AsyncToSync, RmAioFunctions, CrossSyncMethodDecoratorHandler, CrossSyncClassDecoratorHandler
from transformers import ( # noqa: F401 E402
SymbolReplacer,
AsyncToSync,
RmAioFunctions,
CrossSyncMethodDecoratorHandler,
CrossSyncClassDecoratorHandler,
)


def loader():
Expand All @@ -27,10 +34,13 @@ def loader():
test["file_name"] = file_name
yield test


@pytest.mark.parametrize(
"test_dict", loader(), ids=lambda x: f"{x['file_name']}: {x.get('description', '')}"
)
@pytest.mark.skipif(sys.version_info < (3, 9), reason="ast.unparse requires python3.9 or higher")
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="ast.unparse requires python3.9 or higher"
)
def test_e2e_scenario(test_dict):
before_ast = ast.parse(test_dict["before"]).body[0]
got_ast = before_ast
Expand Down
63 changes: 44 additions & 19 deletions tests/unit/data/_sync/test_cross_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync, T
from unittest import mock

class TestCrossSync:

class TestCrossSync:
async def async_iter(self, in_list):
for i in in_list:
yield i
Expand All @@ -38,14 +38,22 @@ def cs_sync(self):
def cs_async(self):
return CrossSync


@pytest.mark.parametrize(
"attr, async_version, sync_version", [
"attr, async_version, sync_version",
[
("is_async", True, False),
("sleep", asyncio.sleep, time.sleep),
("wait", asyncio.wait, concurrent.futures.wait),
("retry_target", api_core.retry.retry_target_async, api_core.retry.retry_target),
("retry_target_stream", api_core.retry.retry_target_stream_async, api_core.retry.retry_target_stream),
(
"retry_target",
api_core.retry.retry_target_async,
api_core.retry.retry_target,
),
(
"retry_target_stream",
api_core.retry.retry_target_stream_async,
api_core.retry.retry_target_stream,
),
("Retry", api_core.retry.AsyncRetry, api_core.retry.Retry),
("Queue", asyncio.Queue, queue.Queue),
("Condition", asyncio.Condition, threading.Condition),
Expand All @@ -59,14 +67,18 @@ def cs_async(self):
("Iterable", typing.AsyncIterable, typing.Iterable),
("Iterator", typing.AsyncIterator, typing.Iterator),
("Generator", typing.AsyncGenerator, typing.Generator),
]
],
)
def test_alias_attributes(self, attr, async_version, sync_version, cs_sync, cs_async):
def test_alias_attributes(
self, attr, async_version, sync_version, cs_sync, cs_async
):
"""
Test basic alias attributes, to ensure they point to the right place
in both sync and async versions.
"""
assert getattr(cs_async, attr) == async_version, f"Failed async version for {attr}"
assert (
getattr(cs_async, attr) == async_version
), f"Failed async version for {attr}"
assert getattr(cs_sync, attr) == sync_version, f"Failed sync version for {attr}"

@pytest.mark.asyncio
Expand Down Expand Up @@ -121,7 +133,7 @@ def test_gather_partials_with_excepptions(self, cs_sync):
Test sync version of CrossSync.gather_partials() with exceptions
"""
with concurrent.futures.ThreadPoolExecutor() as e:
partials = [lambda i=i: i + 1 if i != 3 else 1/0 for i in range(5)]
partials = [lambda i=i: i + 1 if i != 3 else 1 / 0 for i in range(5)]
with pytest.raises(ZeroDivisionError):
cs_sync.gather_partials(partials, sync_executor=e)

Expand All @@ -130,8 +142,10 @@ def test_gather_partials_return_exceptions(self, cs_sync):
Test sync version of CrossSync.gather_partials() with return_exceptions=True
"""
with concurrent.futures.ThreadPoolExecutor() as e:
partials = [lambda i=i: i + 1 if i != 3 else 1/0 for i in range(5)]
results = cs_sync.gather_partials(partials, return_exceptions=True, sync_executor=e)
partials = [lambda i=i: i + 1 if i != 3 else 1 / 0 for i in range(5)]
results = cs_sync.gather_partials(
partials, return_exceptions=True, sync_executor=e
)
assert len(results) == 5
assert results[0] == 1
assert results[1] == 2
Expand All @@ -145,14 +159,15 @@ def test_gather_partials_no_executor(self, cs_sync):
"""
partials = [lambda i=i: i + 1 for i in range(5)]
with pytest.raises(ValueError) as e:
results = cs_sync.gather_partials(partials)
cs_sync.gather_partials(partials)
assert "sync_executor is required" in str(e.value)

@pytest.mark.asyncio
async def test_gather_partials_async(self, cs_async):
"""
Test async version of CrossSync.gather_partials()
"""

async def coro(i):
return i + 1

Expand All @@ -165,8 +180,9 @@ async def test_gather_partials_async_with_exceptions(self, cs_async):
"""
Test async version of CrossSync.gather_partials() with exceptions
"""

async def coro(i):
return i + 1 if i != 3 else 1/0
return i + 1 if i != 3 else 1 / 0

partials = [functools.partial(coro, i) for i in range(5)]
with pytest.raises(ZeroDivisionError):
Expand All @@ -177,8 +193,9 @@ async def test_gather_partials_async_return_exceptions(self, cs_async):
"""
Test async version of CrossSync.gather_partials() with return_exceptions=True
"""

async def coro(i):
return i + 1 if i != 3 else 1/0
return i + 1 if i != 3 else 1 / 0

partials = [functools.partial(coro, i) for i in range(5)]
results = await cs_async.gather_partials(partials, return_exceptions=True)
Expand All @@ -194,13 +211,16 @@ async def test_gather_partials_async_uses_asyncio_gather(self, cs_async):
"""
CrossSync.gather_partials() should use asyncio.gather() internally
"""

async def coro(i):
return i + 1

return_exceptions=object()
return_exceptions = object()
partials = [functools.partial(coro, i) for i in range(5)]
with mock.patch.object(asyncio, "gather", mock.AsyncMock()) as gather:
await cs_async.gather_partials(partials, return_exceptions=return_exceptions)
await cs_async.gather_partials(
partials, return_exceptions=return_exceptions
)
gather.assert_called_once()
found_args, found_kwargs = gather.call_args
assert found_kwargs["return_exceptions"] == return_exceptions
Expand Down Expand Up @@ -249,7 +269,6 @@ async def test_event_wait_async(self, cs_async, break_early):
await cs_async.event_wait(event, async_break_early=break_early)
event.wait.assert_called_once_with()


@pytest.mark.asyncio
async def test_event_wait_async_with_timeout(self, cs_async):
"""
Expand Down Expand Up @@ -308,7 +327,7 @@ def test_create_task(self, cs_sync):
Test creating Future using create_task()
"""
executor = concurrent.futures.ThreadPoolExecutor()
fn = lambda x, y: x + y
fn = lambda x, y: x + y # noqa: E731
result = cs_sync.create_task(fn, 1, y=4, sync_executor=executor)
assert isinstance(result, cs_sync.Task)
assert result.result() == 5
Expand All @@ -327,7 +346,6 @@ def test_create_task_passthrough(self, cs_sync):
assert executor.submit.call_count == 1
assert executor.submit.call_args == ((fn, *args), kwargs)


def test_create_task_no_executor(self, cs_sync):
"""
if no executor is provided, raise an exception
Expand All @@ -341,8 +359,10 @@ async def test_create_task_async(self, cs_async):
"""
Test creating Future using create_task()
"""

async def coro_fn(x, y):
return x + y

result = cs_async.create_task(coro_fn, 1, y=4)
assert isinstance(result, asyncio.Task)
assert await result == 5
Expand All @@ -358,6 +378,7 @@ async def test_create_task_async_passthrough(self, cs_async):
kwargs = {"a": 1, "b": 2}
with mock.patch.object(asyncio, "create_task", mock.Mock()) as create_task:
result = cs_async.create_task(coro_fn, *args, **kwargs)
assert isinstance(result, asyncio.Task)
create_task.assert_called_once()
create_task.assert_called_once_with(coro_fn.return_value)
coro_fn.assert_called_once_with(*args, **kwargs)
Expand All @@ -367,8 +388,10 @@ async def test_create_task_async_with_name(self, cs_async):
"""
Test creating a task with a name
"""

async def coro_fn():
return None

name = "test-name-456"
result = cs_async.create_task(coro_fn, task_name=name)
assert isinstance(result, asyncio.Task)
Expand Down Expand Up @@ -445,7 +468,9 @@ def test_add_mapping_decorator(self, cs_sync, cs_async):
add_mapping_decorator should allow wrapping classes with add_mapping()
"""
for cls in [cs_sync, cs_async]:

@cls.add_mapping_decorator("decorated")
class Decorated:
pass

assert cls.decorated == Decorated
Loading

0 comments on commit a3cb9a6

Please sign in to comment.