diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index d502c22e9..80e384361 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -149,31 +149,38 @@ class RmAioFunctions(ast.NodeTransformer): """ Visits all calls marked with CrossSync.rm_aio, and removes asyncio keywords """ + RM_AIO_FN_NAME = "rm_aio" + RM_AIO_CLASS_NAME = "CrossSync" def __init__(self): self.to_sync = AsyncToSync() + def _is_rm_aio_call(self, node) -> bool: + """ + Check if a node is a CrossSync.rm_aio call + """ + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + if node.func.attr == self.RM_AIO_FN_NAME and node.func.value.id == self.RM_AIO_CLASS_NAME: + return True + return False + def visit_Call(self, node): - if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) and \ - node.func.attr == "rm_aio" and "CrossSync" in node.func.value.id: + if self._is_rm_aio_call(node): return self.visit(self.to_sync.visit(node.args[0])) return self.generic_visit(node) def visit_AsyncWith(self, node): """ - Async with statements are not fully wrapped by calls + `async with` statements can contain multiple async context managers. + + If any of them contains a CrossSync.rm_aio statement, convert into standard `with` statement """ - found_rmaio = False - for item in node.items: - if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute) and isinstance(item.context_expr.func.value, ast.Name) and \ - item.context_expr.func.attr == "rm_aio" and "CrossSync" in item.context_expr.func.value.id: - found_rmaio = True - break - if found_rmaio: + if any(self._is_rm_aio_call(item.context_expr) for item in node.items + ): new_node = ast.copy_location( ast.With( - [self.generic_visit(item) for item in node.items], - [self.generic_visit(stmt) for stmt in node.body], + [self.visit(item) for item in node.items], + [self.visit(stmt) for stmt in node.body], ), node, ) @@ -185,8 +192,7 @@ def visit_AsyncFor(self, node): Async for statements are not fully wrapped by calls """ it = node.iter - if isinstance(it, ast.Call) and isinstance(it.func, ast.Attribute) and isinstance(it.func.value, ast.Name) and \ - it.func.attr == "rm_aio" and "CrossSync" in it.func.value.id: + if self._is_rm_aio_call(it): return ast.copy_location( ast.For( self.visit(node.target),