Skip to content

Commit

Permalink
Add NamedStyles and some other base classes
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwannheden committed Nov 28, 2024
1 parent 3a8818a commit 6cfab68
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 8 deletions.
1 change: 1 addition & 0 deletions rewrite/rewrite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@

# Style
'Style',
'NamedStyles',
]
3 changes: 3 additions & 0 deletions rewrite/rewrite/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def find_first(self, type: type):
return marker
return None

def find_all(self, type: type):
return [m for m in self.markers if isinstance(m, type)]

EMPTY: ClassVar[Markers]

def __eq__(self, other: object) -> bool:
Expand Down
1 change: 1 addition & 0 deletions rewrite/rewrite/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
# Style
'PythonStyle',
'SpacesStyle',
'IntelliJ',

# Formatter
'AutoFormat',
Expand Down
15 changes: 13 additions & 2 deletions rewrite/rewrite/python/format/auto_format.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Optional

from rewrite import Recipe, Tree, Cursor
from rewrite.python import PythonVisitor
from rewrite.java import JavaSourceFile
from rewrite.python import PythonVisitor, SpacesStyle, IntelliJ
from rewrite.visitor import P, T


Expand All @@ -15,4 +16,14 @@ def __init__(self, stop_after: Tree = None):
self._stop_after = stop_after

def visit(self, tree: Optional[Tree], p: P, parent: Optional[Cursor] = None) -> Optional[T]:
pass
self._cursor = parent if parent is not None else Cursor(None, Cursor.ROOT_VALUE)
cu = tree if isinstance(tree, JavaSourceFile) else self._cursor.first_enclosing_or_throw(JavaSourceFile)

tree = SpacesVisitor(cu.get_style(SpacesStyle) or IntelliJ.spaces(), self._stop_after).visit(tree, p, self._cursor.fork())
return tree


class SpacesVisitor(PythonVisitor):
def __init__(self, style: SpacesStyle, stop_after: Tree = None):
self._style = style
self._stop_after = stop_after
45 changes: 43 additions & 2 deletions rewrite/rewrite/python/style.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass

from ..style import Style
from ..style import Style, NamedStyles


class PythonStyle(Style):
Expand Down Expand Up @@ -41,7 +41,6 @@ class Within:
class Other:
before_comma: bool
after_comma: bool
before_semicolon: bool
before_for_semicolon: bool
before_colon: bool
after_colon: bool
Expand All @@ -53,3 +52,45 @@ class Other:
aroundOperators: AroundOperators
within: Within
other: Other


class IntelliJ(NamedStyles):
@classmethod
def spaces(cls) -> SpacesStyle:
return SpacesStyle(
SpacesStyle.BeforeParentheses(
method_call_parentheses=False,
method_parentheses=False,
left_bracket=False,
),
SpacesStyle.AroundOperators(
assignment=True,
equality=True,
relational=True,
bitwise=True,
additive=True,
multiplicative=True,
shift=True,
power=True,
eq_in_named_parameter=False,
eq_in_keyword_argument=False,
),
SpacesStyle.Within(
brackets=False,
method_parentheses=False,
empty_method_parentheses=False,
method_call_parentheses=False,
empty_method_call_parentheses=False,
braces=False,
),
SpacesStyle.Other(
before_comma=False,
after_comma=True,
before_for_semicolon=False,
before_colon=False,
after_colon=True,
before_backslash=True,
before_hash=True,
after_hash=True,
),
)
36 changes: 34 additions & 2 deletions rewrite/rewrite/style.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,43 @@
from __future__ import annotations

from typing import Protocol
from dataclasses import dataclass
from typing import Protocol, TypeVar, Type, Iterable, Optional, Set
from uuid import UUID

from rewrite import Marker


class Style(Protocol):
def merge(self, lower_precedence: Style) -> Style:
...
return self

def apply_defaults(self) -> Style:
return self


S = TypeVar('S', bound=Style)


@dataclass(frozen=True)
class NamedStyles(Marker):
_id: UUID
_name: str
_display_name: str
_description: Optional[str]
_tags: Set[str]
_styles: Iterable[Style]

@classmethod
def merge(cls, style_type: Type[S], named_styles: Iterable[NamedStyles]) -> Optional[S]:
merged = None
for named_style in named_styles:
styles = named_style._styles
if styles is not None:
for style in styles:
if isinstance(style, style_type):
style = style.apply_defaults()
if merged is None:
merged = style
else:
merged = merged.merge(style)
return merged
11 changes: 9 additions & 2 deletions rewrite/rewrite/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Protocol, Optional, Any, TypeVar, runtime_checkable, cast, TYPE_CHECKING, Generic, ClassVar, Callable
from typing import Protocol, Optional, Any, TypeVar, runtime_checkable, cast, TYPE_CHECKING, Generic, ClassVar, \
Callable, Type
from uuid import UUID

from rewrite import Markers
from .markers import Markers
from .style import NamedStyles, Style

if TYPE_CHECKING:
from rewrite import TreeVisitor, ExecutionContext
Expand Down Expand Up @@ -72,6 +74,8 @@ def create_printer(self, cursor: Cursor) -> TreeVisitor[Any, PrintOutputCapture[
...


S = TypeVar('S', bound=Style)

@runtime_checkable
class SourceFile(Tree, Protocol):
@property
Expand Down Expand Up @@ -100,6 +104,9 @@ def print_equals_input(self, input: 'ParserInput', ctx: ExecutionContext) -> boo
printed = self.print_all()
return printed == input.source().read()

def get_style(self, style: Type[S]) -> Optional[S]:
return NamedStyles.merge(style, self.markers.find_all(NamedStyles))


@dataclass(frozen=True)
class FileAttributes:
Expand Down
3 changes: 3 additions & 0 deletions rewrite/rewrite/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def first_enclosing(self, type: Type[P]) -> P:
c = c.parent
return None

def fork(self) -> Cursor:
return Cursor(self.parent.fork(), self.value)


class TreeVisitor(Protocol[T, P]):
_visit_count: int = 0
Expand Down

0 comments on commit 6cfab68

Please sign in to comment.