Skip to content

Commit

Permalink
simplified file processing
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Sep 11, 2024
1 parent 13abfd4 commit 7dc9a2b
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 206 deletions.
4 changes: 2 additions & 2 deletions .cross_sync/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ CrossSync provides a set of annotations to mark up async classes, to guide the g
### Code Generation

Generation can be initiated using `python .cross_sync/generate.py .`
from the root of the project. This will find all classes with the `@CrossSync.export_sync` annotation
in both `/google` and `/tests` directories, and save them to their specified output paths
from the root of the project. This will find all classes with the `__CROSS_SYNC_OUTPUT__ = "path/to/output"`
annotation, and generate a sync version of classes marked with `@CrossSync.export_sync` at the output path.

## Architecture

Expand Down
71 changes: 21 additions & 50 deletions .cross_sync/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,30 @@
from __future__ import annotations
from typing import Sequence
import ast
from dataclasses import dataclass, field
"""
Entrypoint for initiating an async -> sync conversion using CrossSync
Finds all python files rooted in a given directory, and uses
transformers.CrossSyncClassDecoratorHandler to handle any CrossSync class
decorators found in the files.
transformers.CrossSyncFileHandler to handle any files marked with
__CROSS_SYNC_OUTPUT__
"""


@dataclass
class CrossSyncOutputFile:
"""
Represents an output file location.

Multiple decorated async classes may point to the same output location for
their generated sync code. This class holds all the information needed to
write the output file to disk.
"""
def __init__(self, file_path: str, ast_tree):
self.file_path = file_path
self.tree = ast_tree

# The path to the output file
file_path: str
# The import headers to write to the top of the output file
# will be populated when CrossSync.export_sync(include_file_imports=True)
imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = field(
default_factory=list
)
# The set of sync ast.ClassDef nodes to write to the output file
converted_classes: list[ast.ClassDef] = field(default_factory=list)
# the set of classes contained in the file. Used to prevent duplicates
contained_classes: set[str] = field(default_factory=set)
# the set of mypy error codes to ignore at the file level
# configured using CrossSync.export_sync(mypy_ignore=["error_code"])
mypy_ignore: list[str] = field(default_factory=list)

def __hash__(self):
return hash(self.file_path)

def __repr__(self):
return f"CrossSyncOutputFile({self.file_path}, classes={[c.name for c in self.converted_classes]})"

def render(self, with_black=True, save_to_disk=False) -> str:
def render(self, with_black=True, save_to_disk: bool = False) -> str:
"""
Render the output file as a string.
Render the file to a string, and optionally save to disk
Args:
with_black: whether to run the output through black before returning
save_to_disk: whether to write the output to the file path
"""
full_str = (
header = (
"# Copyright 2024 Google LLC\n"
"#\n"
'# Licensed under the Apache License, Version 2.0 (the "License");\n'
Expand All @@ -80,13 +54,7 @@ def render(self, with_black=True, save_to_disk=False) -> str:
"#\n"
"# This file is automatically generated by CrossSync. Do not edit manually.\n"
)
if self.mypy_ignore:
full_str += (
f'\n# mypy: disable-error-code="{",".join(self.mypy_ignore)}"\n\n'
)
full_str += "\n".join([ast.unparse(node) for node in self.imports]) # type: ignore
full_str += "\n\n"
full_str += "\n".join([ast.unparse(node) for node in self.converted_classes]) # type: ignore
full_str = header + ast.unparse(self.converted)
if with_black:
import black # type: ignore
import autoflake # type: ignore
Expand All @@ -96,30 +64,33 @@ def render(self, with_black=True, save_to_disk=False) -> str:
mode=black.FileMode(),
)
if save_to_disk:
# create parent paths if needed
import os
os.makedirs(os.path.dirname(self.file_path), exist_ok=True)
with open(self.file_path, "w") as f:
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
with open(self.output_path, "w") as f:
f.write(full_str)
return full_str


def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]:
import glob
from transformers import CrossSyncClassDecoratorHandler
from transformers import CrossSyncFileHandler

# find all python files in the directory
files = glob.glob(directory + "/**/*.py", recursive=True)
# keep track of the output files pointed to by the annotated classes
artifacts: set[CrossSyncOutputFile] = set()
file_transformer = CrossSyncFileHandler()
# run each file through ast transformation to find all annotated classes
for file in files:
converter = CrossSyncClassDecoratorHandler(file)
new_outputs = converter.convert_file(artifacts)
artifacts.update(new_outputs)
for file_path in files:
file = open(file_path).read()
converted_tree = file_transformer.visit(ast.parse(file))
if converted_tree is not None:
# contains __CROSS_SYNC_OUTPUT__ annotation
artifacts.add(CrossSyncOutputFile(file_path, converted_tree))
# return set of output artifacts
return artifacts


def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]):
for a in artifacts:
a.render(save_to_disk=True)
Expand All @@ -130,5 +101,5 @@ def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]):

search_root = sys.argv[1]
outputs = convert_files_in_dir(search_root)
print(f"Generated {len(outputs)} artifacts: {[a.file_path for a in outputs]}")
print(f"Generated {len(outputs)} artifacts: {[a.file_name for a in outputs]}")
save_artifacts(outputs)
121 changes: 39 additions & 82 deletions .cross_sync/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
# add cross_sync to path
sys.path.append("google/cloud/bigtable/data/_sync/cross_sync")
from _decorators import AstDecorator, ExportSync
from generate import CrossSyncOutputFile


class SymbolReplacer(ast.NodeTransformer):
Expand Down Expand Up @@ -227,96 +226,54 @@ def visit_AsyncFunctionDef(self, node):
raise ValueError(f"node {node.name} failed") from e


class CrossSyncClassDecoratorHandler(ast.NodeTransformer):
class CrossSyncFileHandler(ast.NodeTransformer):
"""
Visits each class in the file, and if it has a CrossSync decorator, it will be transformed.
Uses CrossSyncMethodDecoratorHandler to visit and (potentially) convert each method in the class
Visit each file, and process CrossSync classes if found
"""
def __init__(self, file_path):
self.in_path = file_path
self._artifact_dict: dict[str, CrossSyncOutputFile] = {}
self.imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = []
self.cross_sync_symbol_transformer = SymbolReplacer(
{"CrossSync": "CrossSync._Sync_Impl"}
)
self.cross_sync_method_handler = CrossSyncMethodDecoratorHandler()

def convert_file(
self, artifacts: set[CrossSyncOutputFile] | None = None
) -> set[CrossSyncOutputFile]:
"""
Called to run a file through the ast transformer.

If the file contains any classes marked with CrossSync.export_sync, the
classes will be processed according to the decorator arguments, and
a set of CrossSyncOutputFile objects will be returned for each output file.
If no CrossSync annotations are found, no changes will occur and an
empty set will be returned
"""
tree = ast.parse(open(self.in_path).read())
self._artifact_dict = {f.file_path: f for f in artifacts or []}
self.imports = self._get_imports(tree)
self.visit(tree)
# return set of new artifacts
return set(self._artifact_dict.values()).difference(artifacts or [])
@staticmethod
def _find_cs_output(node):
for i, n in enumerate(node.body):
if isinstance(n, ast.Assign):
for target in n.targets:
if isinstance(target, ast.Name) and target.id == "__CROSS_SYNC_OUTPUT__":
# keep the output path
# remove the statement
node.body.pop(i)
return n.value.value + ".py"

def visit_Module(self, node):
# look for __CROSS_SYNC_OUTPUT__ Assign statement
self.output_path = self._find_cs_output(node)
if self.output_path:
# if found, process the file
return self.generic_visit(node)
else:
# not cross_sync file. Return None
return None

def visit_ClassDef(self, node):
"""
Called for each class in file. If class has a CrossSync decorator, it will be transformed
according to the decorator arguments. Otherwise, no changes will occur
Uses a set of CrossSyncOutputFile objects to store the transformed classes
and avoid duplicate writes
according to the decorator arguments. Otherwise, class is returned unchanged
"""
try:
converted = None
for decorator in node.decorator_list:
try:
handler = AstDecorator.get_for_node(decorator)
if isinstance(handler, ExportSync):
# find the path to write the sync class to
out_file = "/".join(handler.path.rsplit(".")[:-1]) + ".py"
sync_cls_name = handler.path.rsplit(".", 1)[-1]
# find the artifact file for the save location
output_artifact = self._artifact_dict.get(
out_file, CrossSyncOutputFile(out_file)
)
# write converted class details if not already present
if sync_cls_name not in output_artifact.contained_classes:
# transformation is handled in sync_ast_transform method of the decorator
converted = handler.sync_ast_transform(node, globals())
output_artifact.converted_classes.append(converted)
# handle file-level mypy ignores
mypy_ignores = [
s
for s in handler.mypy_ignore
if s not in output_artifact.mypy_ignore
]
output_artifact.mypy_ignore.extend(mypy_ignores)
# handle file-level imports
if not output_artifact.imports and handler.include_file_imports:
output_artifact.imports = self.imports
self._artifact_dict[out_file] = output_artifact
except ValueError:
continue
return converted
except ValueError as e:
raise ValueError(f"failed for class: {node.name}") from e
for decorator in node.decorator_list:
try:
handler = AstDecorator.get_for_node(decorator)
if isinstance(handler, ExportSync):
# transformation is handled in sync_ast_transform method of the decorator
return handler.sync_ast_transform(node, globals())
except ValueError:
# not cross_sync decorator
continue
# cross_sync decorator not found
return node

def _get_imports(
self, tree: ast.Module
) -> list[ast.Import | ast.ImportFrom | ast.Try | ast.If]:
def visit_If(self, node):
"""
Grab the imports from the top of the file
raw imports, as well as try and if statements at the top level are included
remove CrossSync.is_async branches from top-level if statements
"""
imports = []
for node in tree.body:
if isinstance(node, (ast.Import, ast.ImportFrom, ast.Try, ast.If)):
imports.append(self.cross_sync_symbol_transformer.visit(node))
return imports

if isinstance(node.test, ast.Attribute) and isinstance(node.test.value, ast.Name) and node.test.value.id == "CrossSync" and node.test.attr == "is_async":
return node.orelse
return self.generic_visit(node)

22 changes: 6 additions & 16 deletions google/cloud/bigtable/data/_sync/cross_sync/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

if TYPE_CHECKING:
import ast
from typing import Sequence, Callable, Any
from typing import Callable, Any


class AstDecorator:
Expand Down Expand Up @@ -175,25 +175,21 @@ class ExportSync(AstDecorator):
Class decorator for marking async classes to be converted to sync classes
Args:
path: path to output the generated sync class
sync_name: use a new name for the sync class
replace_symbols: a dict of symbols and replacements to use when generating sync class
docstring_format_vars: a dict of variables to replace in the docstring
mypy_ignore: set of mypy errors to ignore in the generated file
include_file_imports: if True, include top-level imports from the file in the generated sync class
add_mapping_for_name: when given, will add a new attribute to CrossSync, so the original class and its sync version can be accessed from CrossSync.<name>
"""

def __init__(
self,
path: str,
sync_name: str | None = None,
*,
replace_symbols: dict[str, str] | None = None,
docstring_format_vars: dict[str, tuple[str, str]] | None = None,
mypy_ignore: Sequence[str] = (),
include_file_imports: bool = True,
add_mapping_for_name: str | None = None,
):
self.path = path
self.sync_name = sync_name
self.replace_symbols = replace_symbols
docstring_format_vars = docstring_format_vars or {}
self.async_docstring_format_vars = {
Expand All @@ -202,8 +198,6 @@ def __init__(
self.sync_docstring_format_vars = {
k: v[1] for k, v in docstring_format_vars.items()
}
self.mypy_ignore = mypy_ignore
self.include_file_imports = include_file_imports
self.add_mapping_for_name = add_mapping_for_name

def async_decorator(self):
Expand All @@ -230,15 +224,11 @@ def sync_ast_transform(self, wrapped_node, transformers_globals):
import ast
import copy

if not self.path:
raise ValueError(
f"{wrapped_node.name} has no path specified in export_sync decorator"
)
# copy wrapped node
wrapped_node = copy.deepcopy(wrapped_node)
# update name
sync_cls_name = self.path.rsplit(".", 1)[-1]
wrapped_node.name = sync_cls_name
if self.sync_name:
wrapped_node.name = self.sync_name
# strip CrossSync decorators
if hasattr(wrapped_node, "decorator_list"):
wrapped_node.decorator_list = [
Expand Down
Loading

0 comments on commit 7dc9a2b

Please sign in to comment.