Skip to content

Commit

Permalink
fixed issues in generation
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Sep 11, 2024
1 parent 7dc9a2b commit 41576f7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
17 changes: 9 additions & 8 deletions .cross_sync/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str:
"#\n"
"# This file is automatically generated by CrossSync. Do not edit manually.\n"
)
full_str = header + ast.unparse(self.converted)
full_str = header + ast.unparse(self.tree)
if with_black:
import black # type: ignore
import autoflake # type: ignore
Expand All @@ -65,8 +65,8 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str:
)
if save_to_disk:
import os
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
with open(self.output_path, "w") as f:
os.makedirs(os.path.dirname(self.file_path), exist_ok=True)
with open(self.file_path, "w") as f:
f.write(full_str)
return full_str

Expand All @@ -82,11 +82,12 @@ def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]:
file_transformer = CrossSyncFileHandler()
# run each file through ast transformation to find all annotated classes
for file_path in files:
file = open(file_path).read()
converted_tree = file_transformer.visit(ast.parse(file))
if converted_tree is not None:
ast_tree = ast.parse(open(file_path).read())
output_path = file_transformer.get_output_path(ast_tree)
if output_path is not None:
# contains __CROSS_SYNC_OUTPUT__ annotation
artifacts.add(CrossSyncOutputFile(file_path, converted_tree))
converted_tree = file_transformer.visit(ast_tree)
artifacts.add(CrossSyncOutputFile(output_path, converted_tree))
# return set of output artifacts
return artifacts

Expand All @@ -101,5 +102,5 @@ def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]):

search_root = sys.argv[1]
outputs = convert_files_in_dir(search_root)
print(f"Generated {len(outputs)} artifacts: {[a.file_name for a in outputs]}")
print(f"Generated {len(outputs)} artifacts: {[a.file_path for a in outputs]}")
save_artifacts(outputs)
10 changes: 4 additions & 6 deletions .cross_sync/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,20 +232,18 @@ class CrossSyncFileHandler(ast.NodeTransformer):
"""

@staticmethod
def _find_cs_output(node):
def get_output_path(node):
for i, n in enumerate(node.body):
if isinstance(n, ast.Assign):
for target in n.targets:
if isinstance(target, ast.Name) and target.id == "__CROSS_SYNC_OUTPUT__":
# keep the output path
# remove the statement
node.body.pop(i)
return n.value.value + ".py"
return n.value.value.replace(".", "/") + ".py"

def visit_Module(self, node):
# look for __CROSS_SYNC_OUTPUT__ Assign statement
self.output_path = self._find_cs_output(node)
if self.output_path:
output_path = self.get_output_path(node)
if output_path:
# if found, process the file
return self.generic_visit(node)
else:
Expand Down

0 comments on commit 41576f7

Please sign in to comment.