From 41576f756d1a5eb9d6d242c8d0e2810c0a06e368 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 10 Sep 2024 18:06:42 -0700 Subject: [PATCH] fixed issues in generation --- .cross_sync/generate.py | 17 +++++++++-------- .cross_sync/transformers.py | 10 ++++------ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index fea4f04f3..5c130079f 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -54,7 +54,7 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str: "#\n" "# This file is automatically generated by CrossSync. Do not edit manually.\n" ) - full_str = header + ast.unparse(self.converted) + full_str = header + ast.unparse(self.tree) if with_black: import black # type: ignore import autoflake # type: ignore @@ -65,8 +65,8 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str: ) if save_to_disk: import os - os.makedirs(os.path.dirname(self.output_path), exist_ok=True) - with open(self.output_path, "w") as f: + os.makedirs(os.path.dirname(self.file_path), exist_ok=True) + with open(self.file_path, "w") as f: f.write(full_str) return full_str @@ -82,11 +82,12 @@ def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]: file_transformer = CrossSyncFileHandler() # run each file through ast transformation to find all annotated classes for file_path in files: - file = open(file_path).read() - converted_tree = file_transformer.visit(ast.parse(file)) - if converted_tree is not None: + ast_tree = ast.parse(open(file_path).read()) + output_path = file_transformer.get_output_path(ast_tree) + if output_path is not None: # contains __CROSS_SYNC_OUTPUT__ annotation - artifacts.add(CrossSyncOutputFile(file_path, converted_tree)) + converted_tree = file_transformer.visit(ast_tree) + artifacts.add(CrossSyncOutputFile(output_path, converted_tree)) # return set of output artifacts return artifacts @@ -101,5 +102,5 @@ def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]): search_root = sys.argv[1] outputs = convert_files_in_dir(search_root) - print(f"Generated {len(outputs)} artifacts: {[a.file_name for a in outputs]}") + print(f"Generated {len(outputs)} artifacts: {[a.file_path for a in outputs]}") save_artifacts(outputs) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 5afef0d41..1744e9da0 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -232,20 +232,18 @@ class CrossSyncFileHandler(ast.NodeTransformer): """ @staticmethod - def _find_cs_output(node): + def get_output_path(node): for i, n in enumerate(node.body): if isinstance(n, ast.Assign): for target in n.targets: if isinstance(target, ast.Name) and target.id == "__CROSS_SYNC_OUTPUT__": # keep the output path - # remove the statement - node.body.pop(i) - return n.value.value + ".py" + return n.value.value.replace(".", "/") + ".py" def visit_Module(self, node): # look for __CROSS_SYNC_OUTPUT__ Assign statement - self.output_path = self._find_cs_output(node) - if self.output_path: + output_path = self.get_output_path(node) + if output_path: # if found, process the file return self.generic_visit(node) else: