diff --git a/rewrite/rewrite/python/format/auto_format.py b/rewrite/rewrite/python/format/auto_format.py index f9fdb96..b89dc8d 100644 --- a/rewrite/rewrite/python/format/auto_format.py +++ b/rewrite/rewrite/python/format/auto_format.py @@ -1,7 +1,7 @@ from typing import Optional from rewrite import Recipe, Tree, Cursor -from rewrite.java import JavaSourceFile +from rewrite.java import JavaSourceFile, MethodDeclaration, J, Space from rewrite.python import PythonVisitor, SpacesStyle, IntelliJ from rewrite.visitor import P, T @@ -27,3 +27,8 @@ class SpacesVisitor(PythonVisitor): def __init__(self, style: SpacesStyle, stop_after: Tree = None): self._style = style self._stop_after = stop_after + + def visit_method_declaration(self, method_declaration: MethodDeclaration, p: P) -> J: + return method_declaration.padding.with_parameters( + method_declaration.padding.parameters.with_before(Space.SINGLE_SPACE if self._style.beforeParentheses.method_parentheses else Space.EMPTY) + ) diff --git a/rewrite/rewrite/python/parser.py b/rewrite/rewrite/python/parser.py index e9b1654..3ec674d 100644 --- a/rewrite/rewrite/python/parser.py +++ b/rewrite/rewrite/python/parser.py @@ -1,9 +1,10 @@ import ast import logging +from dataclasses import dataclass from pathlib import Path from typing import Iterable, Optional -from rewrite import Parser, ParserInput, ExecutionContext, SourceFile, ParseError +from rewrite import Parser, ParserInput, ExecutionContext, SourceFile, ParseError, NamedStyles, Markers, Tree, random_id from rewrite.parser import require_print_equals_input, ParserBuilder from ._parser_visitor import ParserVisitor from .tree import CompilationUnit @@ -11,7 +12,10 @@ logging.basicConfig(level=logging.ERROR) +@dataclass(frozen=True) class PythonParser(Parser): + _styles: Optional[Iterable[NamedStyles]] + def parse_inputs(self, sources: Iterable[ParserInput], relative_to: Optional[Path], ctx: ExecutionContext) -> Iterable[SourceFile]: accepted = (source for source in sources if self.accept(source.path)) @@ -20,6 +24,7 @@ def parse_inputs(self, sources: Iterable[ParserInput], relative_to: Optional[Pat source_str = source.source().read() tree = ast.parse(source_str, source.path) cu = ParserVisitor(source_str).visit(tree).with_source_path(source.path) + cu = cu.with_markers(Markers.build(random_id(), self._styles)) if self._styles else cu cu = require_print_equals_input(self, cu, source, relative_to, ctx) except Exception as e: logging.error(f"An error was encountered while parsing {source.path}: {str(e)}", exc_info=True) @@ -37,6 +42,11 @@ class PythonParserBuilder(ParserBuilder): def __init__(self): self._source_file_type = type(CompilationUnit) self._dsl_name = 'python' + self._styles = None + + def styles(self, styles: Iterable[NamedStyles]): + self._styles = styles + return self def build(self) -> Parser: - return PythonParser() + return PythonParser(self._styles) diff --git a/rewrite/rewrite/visitor.py b/rewrite/rewrite/visitor.py index ad87c28..169fe65 100644 --- a/rewrite/rewrite/visitor.py +++ b/rewrite/rewrite/visitor.py @@ -40,7 +40,7 @@ def first_enclosing(self, type: Type[P]) -> P: return None def fork(self) -> Cursor: - return Cursor(self.parent.fork(), self.value) + return Cursor(self.parent.fork(), self.value) if self.parent else self class TreeVisitor(Protocol[T, P]): diff --git a/rewrite/tests/python/all/format/demo_test.py b/rewrite/tests/python/all/format/demo_test.py index 10307ac..786e66c 100644 --- a/rewrite/tests/python/all/format/demo_test.py +++ b/rewrite/tests/python/all/format/demo_test.py @@ -1,7 +1,7 @@ from typing import Optional from rewrite.java import Space, P -from rewrite.python import PythonVisitor +from rewrite.python import PythonVisitor, AutoFormat from rewrite.test import rewrite_run, python, from_visitor @@ -18,6 +18,25 @@ def getter(self, row): recipe=from_visitor(NoSpaces()) ) + +def test_spaces_before_method_parentheses(): + rewrite_run( + # language=python + python( + """ + class Foo: + def getter (self, row): + pass + """, + """ + class Foo: + def getter(self, row): + pass + """ + ), + recipe=AutoFormat() + ) + class NoSpaces(PythonVisitor): def visit_space(self, space: Optional[Space], loc: Optional[Space.Location], p: P) -> Optional[Space]: return Space.EMPTY if space else None