diff --git a/.cross_sync/README.md b/.cross_sync/README.md index 563fccb3b..0d43c1bb4 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -63,8 +63,8 @@ CrossSync provides a set of annotations to mark up async classes, to guide the g ### Code Generation Generation can be initiated using `python .cross_sync/generate.py .` -from the root of the project. This will find all classes with the `@CrossSync.export_sync` annotation -in both `/google` and `/tests` directories, and save them to their specified output paths +from the root of the project. This will find all classes with the `__CROSS_SYNC_OUTPUT__ = "path/to/output"` +annotation, and generate a sync version of classes marked with `@CrossSync.export_sync` at the output path. ## Architecture diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index d93838e59..fea4f04f3 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -14,56 +14,30 @@ from __future__ import annotations from typing import Sequence import ast -from dataclasses import dataclass, field """ Entrypoint for initiating an async -> sync conversion using CrossSync Finds all python files rooted in a given directory, and uses -transformers.CrossSyncClassDecoratorHandler to handle any CrossSync class -decorators found in the files. +transformers.CrossSyncFileHandler to handle any files marked with +__CROSS_SYNC_OUTPUT__ """ -@dataclass class CrossSyncOutputFile: - """ - Represents an output file location. - Multiple decorated async classes may point to the same output location for - their generated sync code. This class holds all the information needed to - write the output file to disk. - """ + def __init__(self, file_path: str, ast_tree): + self.file_path = file_path + self.tree = ast_tree - # The path to the output file - file_path: str - # The import headers to write to the top of the output file - # will be populated when CrossSync.export_sync(include_file_imports=True) - imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = field( - default_factory=list - ) - # The set of sync ast.ClassDef nodes to write to the output file - converted_classes: list[ast.ClassDef] = field(default_factory=list) - # the set of classes contained in the file. Used to prevent duplicates - contained_classes: set[str] = field(default_factory=set) - # the set of mypy error codes to ignore at the file level - # configured using CrossSync.export_sync(mypy_ignore=["error_code"]) - mypy_ignore: list[str] = field(default_factory=list) - - def __hash__(self): - return hash(self.file_path) - - def __repr__(self): - return f"CrossSyncOutputFile({self.file_path}, classes={[c.name for c in self.converted_classes]})" - - def render(self, with_black=True, save_to_disk=False) -> str: + def render(self, with_black=True, save_to_disk: bool = False) -> str: """ - Render the output file as a string. + Render the file to a string, and optionally save to disk Args: with_black: whether to run the output through black before returning save_to_disk: whether to write the output to the file path """ - full_str = ( + header = ( "# Copyright 2024 Google LLC\n" "#\n" '# Licensed under the Apache License, Version 2.0 (the "License");\n' @@ -80,13 +54,7 @@ def render(self, with_black=True, save_to_disk=False) -> str: "#\n" "# This file is automatically generated by CrossSync. Do not edit manually.\n" ) - if self.mypy_ignore: - full_str += ( - f'\n# mypy: disable-error-code="{",".join(self.mypy_ignore)}"\n\n' - ) - full_str += "\n".join([ast.unparse(node) for node in self.imports]) # type: ignore - full_str += "\n\n" - full_str += "\n".join([ast.unparse(node) for node in self.converted_classes]) # type: ignore + full_str = header + ast.unparse(self.converted) if with_black: import black # type: ignore import autoflake # type: ignore @@ -96,30 +64,33 @@ def render(self, with_black=True, save_to_disk=False) -> str: mode=black.FileMode(), ) if save_to_disk: - # create parent paths if needed import os - os.makedirs(os.path.dirname(self.file_path), exist_ok=True) - with open(self.file_path, "w") as f: + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + with open(self.output_path, "w") as f: f.write(full_str) return full_str def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]: import glob - from transformers import CrossSyncClassDecoratorHandler + from transformers import CrossSyncFileHandler # find all python files in the directory files = glob.glob(directory + "/**/*.py", recursive=True) # keep track of the output files pointed to by the annotated classes artifacts: set[CrossSyncOutputFile] = set() + file_transformer = CrossSyncFileHandler() # run each file through ast transformation to find all annotated classes - for file in files: - converter = CrossSyncClassDecoratorHandler(file) - new_outputs = converter.convert_file(artifacts) - artifacts.update(new_outputs) + for file_path in files: + file = open(file_path).read() + converted_tree = file_transformer.visit(ast.parse(file)) + if converted_tree is not None: + # contains __CROSS_SYNC_OUTPUT__ annotation + artifacts.add(CrossSyncOutputFile(file_path, converted_tree)) # return set of output artifacts return artifacts + def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]): for a in artifacts: a.render(save_to_disk=True) @@ -130,5 +101,5 @@ def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]): search_root = sys.argv[1] outputs = convert_files_in_dir(search_root) - print(f"Generated {len(outputs)} artifacts: {[a.file_path for a in outputs]}") + print(f"Generated {len(outputs)} artifacts: {[a.file_name for a in outputs]}") save_artifacts(outputs) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index b6a34c690..5afef0d41 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -32,7 +32,6 @@ # add cross_sync to path sys.path.append("google/cloud/bigtable/data/_sync/cross_sync") from _decorators import AstDecorator, ExportSync -from generate import CrossSyncOutputFile class SymbolReplacer(ast.NodeTransformer): @@ -227,96 +226,54 @@ def visit_AsyncFunctionDef(self, node): raise ValueError(f"node {node.name} failed") from e -class CrossSyncClassDecoratorHandler(ast.NodeTransformer): +class CrossSyncFileHandler(ast.NodeTransformer): """ - Visits each class in the file, and if it has a CrossSync decorator, it will be transformed. - - Uses CrossSyncMethodDecoratorHandler to visit and (potentially) convert each method in the class + Visit each file, and process CrossSync classes if found """ - def __init__(self, file_path): - self.in_path = file_path - self._artifact_dict: dict[str, CrossSyncOutputFile] = {} - self.imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = [] - self.cross_sync_symbol_transformer = SymbolReplacer( - {"CrossSync": "CrossSync._Sync_Impl"} - ) - self.cross_sync_method_handler = CrossSyncMethodDecoratorHandler() - - def convert_file( - self, artifacts: set[CrossSyncOutputFile] | None = None - ) -> set[CrossSyncOutputFile]: - """ - Called to run a file through the ast transformer. - If the file contains any classes marked with CrossSync.export_sync, the - classes will be processed according to the decorator arguments, and - a set of CrossSyncOutputFile objects will be returned for each output file. - - If no CrossSync annotations are found, no changes will occur and an - empty set will be returned - """ - tree = ast.parse(open(self.in_path).read()) - self._artifact_dict = {f.file_path: f for f in artifacts or []} - self.imports = self._get_imports(tree) - self.visit(tree) - # return set of new artifacts - return set(self._artifact_dict.values()).difference(artifacts or []) + @staticmethod + def _find_cs_output(node): + for i, n in enumerate(node.body): + if isinstance(n, ast.Assign): + for target in n.targets: + if isinstance(target, ast.Name) and target.id == "__CROSS_SYNC_OUTPUT__": + # keep the output path + # remove the statement + node.body.pop(i) + return n.value.value + ".py" + + def visit_Module(self, node): + # look for __CROSS_SYNC_OUTPUT__ Assign statement + self.output_path = self._find_cs_output(node) + if self.output_path: + # if found, process the file + return self.generic_visit(node) + else: + # not cross_sync file. Return None + return None def visit_ClassDef(self, node): """ Called for each class in file. If class has a CrossSync decorator, it will be transformed - according to the decorator arguments. Otherwise, no changes will occur - - Uses a set of CrossSyncOutputFile objects to store the transformed classes - and avoid duplicate writes + according to the decorator arguments. Otherwise, class is returned unchanged """ - try: - converted = None - for decorator in node.decorator_list: - try: - handler = AstDecorator.get_for_node(decorator) - if isinstance(handler, ExportSync): - # find the path to write the sync class to - out_file = "/".join(handler.path.rsplit(".")[:-1]) + ".py" - sync_cls_name = handler.path.rsplit(".", 1)[-1] - # find the artifact file for the save location - output_artifact = self._artifact_dict.get( - out_file, CrossSyncOutputFile(out_file) - ) - # write converted class details if not already present - if sync_cls_name not in output_artifact.contained_classes: - # transformation is handled in sync_ast_transform method of the decorator - converted = handler.sync_ast_transform(node, globals()) - output_artifact.converted_classes.append(converted) - # handle file-level mypy ignores - mypy_ignores = [ - s - for s in handler.mypy_ignore - if s not in output_artifact.mypy_ignore - ] - output_artifact.mypy_ignore.extend(mypy_ignores) - # handle file-level imports - if not output_artifact.imports and handler.include_file_imports: - output_artifact.imports = self.imports - self._artifact_dict[out_file] = output_artifact - except ValueError: - continue - return converted - except ValueError as e: - raise ValueError(f"failed for class: {node.name}") from e + for decorator in node.decorator_list: + try: + handler = AstDecorator.get_for_node(decorator) + if isinstance(handler, ExportSync): + # transformation is handled in sync_ast_transform method of the decorator + return handler.sync_ast_transform(node, globals()) + except ValueError: + # not cross_sync decorator + continue + # cross_sync decorator not found + return node - def _get_imports( - self, tree: ast.Module - ) -> list[ast.Import | ast.ImportFrom | ast.Try | ast.If]: + def visit_If(self, node): """ - Grab the imports from the top of the file - - raw imports, as well as try and if statements at the top level are included + remove CrossSync.is_async branches from top-level if statements """ - imports = [] - for node in tree.body: - if isinstance(node, (ast.Import, ast.ImportFrom, ast.Try, ast.If)): - imports.append(self.cross_sync_symbol_transformer.visit(node)) - return imports - + if isinstance(node.test, ast.Attribute) and isinstance(node.test.value, ast.Name) and node.test.value.id == "CrossSync" and node.test.attr == "is_async": + return node.orelse + return self.generic_visit(node) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index b1456951b..35219bee1 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: import ast - from typing import Sequence, Callable, Any + from typing import Callable, Any class AstDecorator: @@ -175,25 +175,21 @@ class ExportSync(AstDecorator): Class decorator for marking async classes to be converted to sync classes Args: - path: path to output the generated sync class + sync_name: use a new name for the sync class replace_symbols: a dict of symbols and replacements to use when generating sync class docstring_format_vars: a dict of variables to replace in the docstring - mypy_ignore: set of mypy errors to ignore in the generated file - include_file_imports: if True, include top-level imports from the file in the generated sync class add_mapping_for_name: when given, will add a new attribute to CrossSync, so the original class and its sync version can be accessed from CrossSync. """ def __init__( self, - path: str, + sync_name: str | None = None, *, replace_symbols: dict[str, str] | None = None, docstring_format_vars: dict[str, tuple[str, str]] | None = None, - mypy_ignore: Sequence[str] = (), - include_file_imports: bool = True, add_mapping_for_name: str | None = None, ): - self.path = path + self.sync_name = sync_name self.replace_symbols = replace_symbols docstring_format_vars = docstring_format_vars or {} self.async_docstring_format_vars = { @@ -202,8 +198,6 @@ def __init__( self.sync_docstring_format_vars = { k: v[1] for k, v in docstring_format_vars.items() } - self.mypy_ignore = mypy_ignore - self.include_file_imports = include_file_imports self.add_mapping_for_name = add_mapping_for_name def async_decorator(self): @@ -230,15 +224,11 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): import ast import copy - if not self.path: - raise ValueError( - f"{wrapped_node.name} has no path specified in export_sync decorator" - ) # copy wrapped node wrapped_node = copy.deepcopy(wrapped_node) # update name - sync_cls_name = self.path.rsplit(".", 1)[-1] - wrapped_node.name = sync_cls_name + if self.sync_name: + wrapped_node.name = self.sync_name # strip CrossSync decorators if hasattr(wrapped_node, "decorator_list"): wrapped_node.decorator_list = [ diff --git a/tests/system/cross_sync/test_cases/cross_sync_classes.yaml b/tests/system/cross_sync/test_cases/cross_sync_files.yaml similarity index 70% rename from tests/system/cross_sync/test_cases/cross_sync_classes.yaml rename to tests/system/cross_sync/test_cases/cross_sync_files.yaml index f38335e87..a49b8189e 100644 --- a/tests/system/cross_sync/test_cases/cross_sync_classes.yaml +++ b/tests/system/cross_sync/test_cases/cross_sync_files.yaml @@ -1,15 +1,40 @@ tests: - - description: "No conversion needed" + - description: "No output annotation" before: | - @CrossSync.export_sync(path="example.sync.MyClass") class MyAsyncClass: async def my_method(self): pass transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler + after: null + + - description: "CrossSync.export_sync with default sync_name" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync + class MyClass: + async def my_method(self): + pass + + transformers: + - name: CrossSyncFileHandler + after: | + class MyClass: + + async def my_method(self): + pass + + - description: "CrossSync.export_sync with custom sync_name" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync(sync_name="MyClass") + class MyAsyncClass: + async def my_method(self): + pass + + transformers: + - name: CrossSyncFileHandler after: | class MyClass: @@ -18,8 +43,9 @@ tests: - description: "CrossSync.export_sync with replace_symbols" before: | + __CROSS_SYNC_OUTPUT__ = "out.path" @CrossSync.export_sync( - path="example.sync.MyClass", + sync_name="MyClass", replace_symbols={"AsyncBase": "SyncBase", "ParentA": "ParentB"} ) class MyAsyncClass(ParentA): @@ -27,9 +53,7 @@ tests: self.base = base transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | class MyClass(ParentB): @@ -38,24 +62,24 @@ tests: - description: "CrossSync.export_sync with docstring formatting" before: | + __CROSS_SYNC_OUTPUT__ = "out.path" @CrossSync.export_sync( - path="example.sync.MyClass", + sync_name="MyClass", docstring_format_vars={"type": ("async", "sync")} ) class MyAsyncClass: """This is a {type} class.""" transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | class MyClass: """This is a sync class.""" - description: "CrossSync.export_sync with multiple decorators and methods" before: | - @CrossSync.export_sync(path="example.sync.MyClass") + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync(sync_name="MyClass") @some_other_decorator class MyAsyncClass: @CrossSync.convert @@ -75,9 +99,7 @@ tests: pass transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | @some_other_decorator class MyClass: @@ -95,7 +117,8 @@ tests: - description: "CrossSync.export_sync with nested classes" before: | - @CrossSync.export_sync(path="example.sync.MyClass", replace_symbols={"AsyncBase": "SyncBase"}) + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync(sync_name="MyClass", replace_symbols={"AsyncBase": "SyncBase"}) class MyAsyncClass: class NestedAsyncClass: async def nested_method(self, base: AsyncBase): @@ -110,9 +133,7 @@ tests: nested = self.NestedAsyncClass() CrossSync.rm_aio(await nested.nested_method()) transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | class MyClass: @@ -127,8 +148,9 @@ tests: - description: "CrossSync.export_sync with add_mapping" before: | + __CROSS_SYNC_OUTPUT__ = "out.path" @CrossSync.export_sync( - path="example.sync.MyClass", + sync_name="MyClass", add_mapping_for_name="MyClass" ) class MyAsyncClass: @@ -136,9 +158,7 @@ tests: pass transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | @CrossSync._Sync_Impl.add_mapping_decorator("MyClass") class MyClass: @@ -148,7 +168,8 @@ tests: - description: "CrossSync.export_sync with CrossSync calls" before: | - @CrossSync.export_sync(path="example.sync.MyClass") + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync(sync_name="MyClass") class MyAsyncClass: @CrossSync.convert async def my_method(self): @@ -156,9 +177,7 @@ tests: CrossSync.rm_aio(await CrossSync.yield_to_event_loop()) transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | class MyClass: diff --git a/tests/system/cross_sync/test_cross_sync_e2e.py b/tests/system/cross_sync/test_cross_sync_e2e.py index 489e042fe..bd08ed6cb 100644 --- a/tests/system/cross_sync/test_cross_sync_e2e.py +++ b/tests/system/cross_sync/test_cross_sync_e2e.py @@ -15,7 +15,7 @@ AsyncToSync, RmAioFunctions, CrossSyncMethodDecoratorHandler, - CrossSyncClassDecoratorHandler, + CrossSyncFileHandler, ) @@ -42,7 +42,7 @@ def loader(): sys.version_info < (3, 9), reason="ast.unparse requires python3.9 or higher" ) def test_e2e_scenario(test_dict): - before_ast = ast.parse(test_dict["before"]).body[0] + before_ast = ast.parse(test_dict["before"]) got_ast = before_ast for transformer_info in test_dict["transformers"]: # transformer can be passed as a string, or a dict with name and args @@ -54,6 +54,12 @@ def test_e2e_scenario(test_dict): transformer_args = transformer_info.get("args", {}) transformer = transformer_class(**transformer_args) got_ast = transformer.visit(got_ast) - final_str = black.format_str(ast.unparse(got_ast), mode=black.FileMode()) - expected_str = black.format_str(test_dict["after"], mode=black.FileMode()) + if got_ast is None: + final_str = "" + else: + final_str = black.format_str(ast.unparse(got_ast), mode=black.FileMode()) + if test_dict.get("after") is None: + expected_str = "" + else: + expected_str = black.format_str(test_dict["after"], mode=black.FileMode()) assert final_str == expected_str, f"Expected:\n{expected_str}\nGot:\n{final_str}" diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py index 988d8d113..6c817fd9d 100644 --- a/tests/unit/data/_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_sync/test_cross_sync_decorators.py @@ -45,39 +45,27 @@ def test_ctor_defaults(self): """ Should set default values for path, add_mapping_for_name, and docstring_format_vars """ - with pytest.raises(TypeError) as exc: - self._get_class()() - assert "missing 1 required positional argument" in str(exc.value) - path = object() - instance = self._get_class()(path) - assert instance.path is path + instance = self._get_class()() + assert instance.sync_name is None assert instance.replace_symbols is None - assert instance.mypy_ignore == () - assert instance.include_file_imports is True assert instance.add_mapping_for_name is None assert instance.async_docstring_format_vars == {} assert instance.sync_docstring_format_vars == {} def test_ctor(self): - path = object() + sync_name = "sync_name" replace_symbols = {"a": "b"} docstring_format_vars = {"A": (1, 2)} - mypy_ignore = ("a", "b") - include_file_imports = False add_mapping_for_name = "test_name" instance = self._get_class()( - path=path, + sync_name, replace_symbols=replace_symbols, docstring_format_vars=docstring_format_vars, - mypy_ignore=mypy_ignore, - include_file_imports=include_file_imports, add_mapping_for_name=add_mapping_for_name, ) - assert instance.path is path + assert instance.sync_name is sync_name assert instance.replace_symbols is replace_symbols - assert instance.mypy_ignore is mypy_ignore - assert instance.include_file_imports is include_file_imports assert instance.add_mapping_for_name is add_mapping_for_name assert instance.async_docstring_format_vars == {"A": 1} assert instance.sync_docstring_format_vars == {"A": 2} @@ -87,7 +75,7 @@ def test_class_decorator(self): Should return class being decorated """ unwrapped_class = mock.Mock - wrapped_class = self._get_class().decorator(unwrapped_class, path=1) + wrapped_class = self._get_class().decorator(unwrapped_class, sync_name="s") assert unwrapped_class == wrapped_class def test_class_decorator_adds_mapping(self): @@ -97,11 +85,13 @@ def test_class_decorator_adds_mapping(self): with mock.patch.object(CrossSync, "add_mapping") as add_mapping: mock_cls = mock.Mock # check decoration with no add_mapping - self._get_class().decorator(path=1)(mock_cls) + self._get_class().decorator(sync_name="s")(mock_cls) assert add_mapping.call_count == 0 # check decoration with add_mapping name = "test_name" - self._get_class().decorator(path=1, add_mapping_for_name=name)(mock_cls) + self._get_class().decorator(sync_name="s", add_mapping_for_name=name)( + mock_cls + ) assert add_mapping.call_count == 1 add_mapping.assert_called_once_with(name, mock_cls) @@ -122,13 +112,13 @@ def test_class_decorator_docstring_update(self, docstring, format_vars, expected of the class being decorated """ - @ExportSync.decorator(path=1, docstring_format_vars=format_vars) + @ExportSync.decorator(sync_name="s", docstring_format_vars=format_vars) class Class: __doc__ = docstring assert Class.__doc__ == expected # check internal state - instance = self._get_class()(path=1, docstring_format_vars=format_vars) + instance = self._get_class()(sync_name="s", docstring_format_vars=format_vars) async_replacements = {k: v[0] for k, v in format_vars.items()} sync_replacements = {k: v[1] for k, v in format_vars.items()} assert instance.async_docstring_format_vars == async_replacements @@ -138,7 +128,7 @@ def test_sync_ast_transform_replaces_name(self, globals_mock): """ Should update the name of the new class """ - decorator = self._get_class()("path.to.SyncClass") + decorator = self._get_class()("SyncClass") mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[]) result = decorator.sync_ast_transform(mock_node, globals_mock)