Skip to content

Commit

Permalink
Enable strict pyright check, fix all type hints, enable pyright check…
Browse files Browse the repository at this point in the history
… in CI
  • Loading branch information
emdoyle committed Feb 8, 2024
1 parent 38cf7b0 commit 7eb22c0
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 34 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ jobs:
coverage run --branch --source=../ -m pytest
coverage report
cd ..
- name: Check types with pyright
run: |
pyright .
- name: Check modguard
run: |
pip install .
Expand Down
2 changes: 2 additions & 0 deletions modguard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .boundary import Boundary
from .public import public

__all__ = ["Boundary", "public"]
2 changes: 1 addition & 1 deletion modguard/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def check(root: str, exclude_paths: Optional[list[str]] = None) -> list[ErrorInf

boundary_trie = build_boundary_trie(root, exclude_paths=exclude_paths)

errors = []
errors: list[ErrorInfo] = []
for file_path in utils.walk_pyfiles(root, exclude_paths=exclude_paths):
mod_path = utils.file_to_module_path(file_path)
nearest_boundary = boundary_trie.find_nearest(mod_path)
Expand Down
4 changes: 2 additions & 2 deletions modguard/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def print_invalid_exclude(path: str) -> None:
)


def parse_base_arguments(args) -> argparse.Namespace:
def parse_base_arguments(args: list[str]) -> argparse.Namespace:
base_parser = argparse.ArgumentParser(
prog="modguard",
add_help=True,
Expand All @@ -62,7 +62,7 @@ def parse_base_arguments(args) -> argparse.Namespace:
return base_parser.parse_args(args)


def parse_init_arguments(args) -> argparse.Namespace:
def parse_init_arguments(args: list[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(
prog="modguard init",
description="Initialize boundaries in a repository with modguard",
Expand Down
2 changes: 2 additions & 0 deletions modguard/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .boundary import BoundaryNode, BoundaryTrie
from .public import PublicMember

__all__ = ["BoundaryNode", "BoundaryTrie", "PublicMember"]
7 changes: 3 additions & 4 deletions modguard/parsing/boundary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import ast
import os
from typing import Optional

from modguard.core.boundary import BoundaryTrie
Expand All @@ -14,7 +13,7 @@ def __init__(self):
self.is_modguard_boundary_imported = False
self.found_boundary = False

def visit_ImportFrom(self, node):
def visit_ImportFrom(self, node: ast.ImportFrom):
# Check if 'Boundary' is imported specifically from a 'modguard'-rooted module
is_modguard_module_import = node.module is not None and (
node.module == "modguard" or node.module.startswith("modguard.")
Expand All @@ -25,14 +24,14 @@ def visit_ImportFrom(self, node):
self.is_modguard_boundary_imported = True
self.generic_visit(node)

def visit_Import(self, node):
def visit_Import(self, node: ast.Import):
# Check if 'modguard' is imported
for alias in node.names:
if alias.name == "modguard":
self.is_modguard_boundary_imported = True
self.generic_visit(node)

def visit_Call(self, node):
def visit_Call(self, node: ast.Call):
if self.is_modguard_boundary_imported:
if isinstance(node.func, ast.Attribute) and node.func.attr == "Boundary":
if (
Expand Down
6 changes: 3 additions & 3 deletions modguard/parsing/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class IgnoreDirective:


def get_ignore_directives(file_content: str) -> dict[int, IgnoreDirective]:
ignores = {}
ignores: dict[int, IgnoreDirective] = {}
lines = file_content.splitlines()
for lineno, line in enumerate(lines):
normal_lineno = lineno + 1
Expand Down Expand Up @@ -53,7 +53,7 @@ def _get_ignored_modules(self, lineno: int) -> Optional[list[str]]:
)
return directive.modules if directive else None

def visit_ImportFrom(self, node):
def visit_ImportFrom(self, node: ast.ImportFrom):
# For relative imports (level > 0), adjust the base module path
if node.module is not None and node.level > 0:
num_paths_to_strip = node.level - 1 if self.is_package else node.level
Expand Down Expand Up @@ -91,7 +91,7 @@ def visit_ImportFrom(self, node):

self.generic_visit(node)

def visit_Import(self, node):
def visit_Import(self, node: ast.Import):
ignored_modules = self._get_ignored_modules(node.lineno)
if ignored_modules is not None and len(ignored_modules) == 0:
# Empty ignore list signifies blanket ignore of following import
Expand Down
49 changes: 26 additions & 23 deletions modguard/parsing/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, module_name: str):
self.module_name = module_name
self.import_found = False

def visit_ImportFrom(self, node):
def visit_ImportFrom(self, node: ast.ImportFrom):
if self.module_name:
is_modguard_module_import = node.module is not None and (
node.module == "modguard" or node.module.startswith("modguard.")
Expand All @@ -25,7 +25,7 @@ def visit_ImportFrom(self, node):
return
self.generic_visit(node)

def visit_Import(self, node):
def visit_Import(self, node: ast.Import):
for alias in node.names:
if alias.name == "modguard":
self.import_found = True
Expand All @@ -46,7 +46,7 @@ def __init__(self, current_mod_path: str, is_package: bool = False):
self.is_package = is_package
self.public_members: list[PublicMember] = []

def visit_ImportFrom(self, node):
def visit_ImportFrom(self, node: ast.ImportFrom):
is_modguard_module_import = node.module is not None and (
node.module == "modguard" or node.module.startswith("modguard.")
)
Expand All @@ -56,7 +56,7 @@ def visit_ImportFrom(self, node):
self.is_modguard_public_imported = True
self.generic_visit(node)

def visit_Import(self, node):
def visit_Import(self, node: ast.Import):
for alias in node.names:
if alias.name == "modguard":
self.is_modguard_public_imported = True
Expand Down Expand Up @@ -107,7 +107,7 @@ def visit_ClassDef(self, node: ast.ClassDef):
self._add_public_member_from_decorator(node=node, decorator=decorator)
self.generic_visit(node)

def visit_Call(self, node):
def visit_Call(self, node: ast.Call):
parent_node = getattr(node, "parent")
grandparent_node = getattr(parent_node, "parent")
top_level = isinstance(parent_node, ast.Module)
Expand Down Expand Up @@ -188,36 +188,39 @@ def __init__(self, member_name: str):
self.matched_assignment = False
self.depth = 0

def _check_assignment(self, node):
if self.depth == 0:
for target in node.targets:
if isinstance(target, ast.Name) and target.id == self.member_name:
def _check_assignment_target(
self, target: Union[ast.expr, ast.Name, ast.Attribute, ast.Subscript]
):
if isinstance(target, ast.Name) and target.id == self.member_name:
self.matched_lineno = target.end_lineno
self.matched_assignment = True
return
elif isinstance(target, ast.List) or isinstance(target, ast.Tuple):
for elt in target.elts:
if isinstance(elt, ast.Name) and elt.id == self.member_name:
self.matched_lineno = target.end_lineno
self.matched_assignment = True
return
elif isinstance(target, ast.List) or isinstance(target, ast.Tuple):
for elt in target.elts:
if isinstance(elt, ast.Name) and elt.id == self.member_name:
self.matched_lineno = target.end_lineno
self.matched_assignment = True
return

def visit_Assign(self, node):
self._check_assignment(node)

def visit_Assign(self, node: ast.Assign):
if self.depth == 0:
for target in node.targets:
self._check_assignment_target(target)
self.generic_visit(node)

def visit_AnnAssign(self, node):
self._check_assignment(node)
def visit_AnnAssign(self, node: ast.AnnAssign):
if self.depth == 0:
self._check_assignment_target(node.target)
self.generic_visit(node)

def visit_Global(self, node):
def visit_Global(self, node: ast.Global):
if self.member_name in node.names:
self.matched_lineno = node.end_lineno
self.matched_assignment = True
return
self.generic_visit(node)

def visit_FunctionDef(self, node):
def visit_FunctionDef(self, node: ast.FunctionDef):
if self.depth == 0 and node.name == self.member_name:
self.matched_lineno = node.lineno
return
Expand All @@ -226,7 +229,7 @@ def visit_FunctionDef(self, node):
self.generic_visit(node)
self.depth -= 1

def visit_ClassDef(self, node):
def visit_ClassDef(self, node: ast.ClassDef):
if self.depth == 0 and node.name == self.member_name:
self.matched_lineno = node.lineno
return
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Issues = "https://github.com/never-over/modguard/issues"
[tool.pyright]
include = ["modguard"]
exclude = ["**/__pycache__"]
strict = ["modguard"]


[build-system]
Expand Down

0 comments on commit 7eb22c0

Please sign in to comment.