Skip to content

Commit

Permalink
cleaning up rm_aio
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Oct 24, 2024
1 parent 8830375 commit 6be8180
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions .cross_sync/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,31 +149,38 @@ class RmAioFunctions(ast.NodeTransformer):
"""
Visits all calls marked with CrossSync.rm_aio, and removes asyncio keywords
"""
RM_AIO_FN_NAME = "rm_aio"
RM_AIO_CLASS_NAME = "CrossSync"

def __init__(self):
self.to_sync = AsyncToSync()

def _is_rm_aio_call(self, node) -> bool:
"""
Check if a node is a CrossSync.rm_aio call
"""
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name):
if node.func.attr == self.RM_AIO_FN_NAME and node.func.value.id == self.RM_AIO_CLASS_NAME:
return True
return False

def visit_Call(self, node):
if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) and \
node.func.attr == "rm_aio" and "CrossSync" in node.func.value.id:
if self._is_rm_aio_call(node):
return self.visit(self.to_sync.visit(node.args[0]))
return self.generic_visit(node)

def visit_AsyncWith(self, node):
"""
Async with statements are not fully wrapped by calls
`async with` statements can contain multiple async context managers.
If any of them contains a CrossSync.rm_aio statement, convert into standard `with` statement
"""
found_rmaio = False
for item in node.items:
if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute) and isinstance(item.context_expr.func.value, ast.Name) and \
item.context_expr.func.attr == "rm_aio" and "CrossSync" in item.context_expr.func.value.id:
found_rmaio = True
break
if found_rmaio:
if any(self._is_rm_aio_call(item.context_expr) for item in node.items
):
new_node = ast.copy_location(
ast.With(
[self.generic_visit(item) for item in node.items],
[self.generic_visit(stmt) for stmt in node.body],
[self.visit(item) for item in node.items],
[self.visit(stmt) for stmt in node.body],
),
node,
)
Expand All @@ -185,8 +192,7 @@ def visit_AsyncFor(self, node):
Async for statements are not fully wrapped by calls
"""
it = node.iter
if isinstance(it, ast.Call) and isinstance(it.func, ast.Attribute) and isinstance(it.func.value, ast.Name) and \
it.func.attr == "rm_aio" and "CrossSync" in it.func.value.id:
if self._is_rm_aio_call(it):
return ast.copy_location(
ast.For(
self.visit(node.target),
Expand Down

0 comments on commit 6be8180

Please sign in to comment.