diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 38dff42e8..a5bb49f51 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -33,13 +33,14 @@ repos:
             (?x)^(
                 tests/.*|
                 # TODO: gradually enable
-                narwhals/series\.py$|
+                narwhals/series\.py|
                 # TODO: gradually enable
-                narwhals/dataframe\.py$|
+                narwhals/dataframe\.py|
                 # TODO: gradually enable
-                narwhals/dependencies\.py$|
+                narwhals/dependencies\.py|
                 # private, so less urgent to document too well
-                narwhals/_.*
+                narwhals/_.*|
+                ^utils/.*
             )$
 -   repo: local
     hooks:
@@ -55,6 +56,12 @@ repos:
         language: python
         files: ^narwhals/
         exclude: ^narwhals/dependencies\.py
+    -   id: check-docstrings-execute
+        name: check docstrings execute
+        entry: python utils/check_docstrings.py
+        language: python
+        additional_dependencies: [ruff]
+        files: ^narwhals/
 -   repo: https://github.com/kynan/nbstripout
     rev: 0.8.0
     hooks:
diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py
index d2249d7f9..d83cacb62 100644
--- a/narwhals/stable/v1/__init__.py
+++ b/narwhals/stable/v1/__init__.py
@@ -1773,7 +1773,7 @@ def narwhalify(
         Instead of writing
 
         >>> import narwhals as nw
-        >>> def func(df):
+        >>> def agnostic_group_by_sum(df):
         ...     df = nw.from_native(df, pass_through=True)
         ...     df = df.group_by("a").agg(nw.col("b").sum())
         ...     return nw.to_native(df)
@@ -1781,7 +1781,7 @@ def narwhalify(
         you can just write
 
         >>> @nw.narwhalify
-        ... def func(df):
+        ... def agnostic_group_by_sum(df):
         ...     return df.group_by("a").agg(nw.col("b").sum())
     """
     pass_through = validate_strict_and_pass_though(
diff --git a/narwhals/translate.py b/narwhals/translate.py
index 15203a949..8e3c6bc47 100644
--- a/narwhals/translate.py
+++ b/narwhals/translate.py
@@ -816,7 +816,7 @@ def narwhalify(
         Instead of writing
 
         >>> import narwhals as nw
-        >>> def func(df):
+        >>> def agnostic_group_by_sum(df):
         ...     df = nw.from_native(df, pass_through=True)
         ...     df = df.group_by("a").agg(nw.col("b").sum())
         ...     return nw.to_native(df)
@@ -824,7 +824,7 @@ def narwhalify(
         you can just write
 
         >>> @nw.narwhalify
-        ... def func(df):
+        ... def agnostic_group_by_sum(df):
         ...     return df.group_by("a").agg(nw.col("b").sum())
     """
     from narwhals.utils import validate_strict_and_pass_though
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 0df99e50d..652da32d8 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -9,6 +9,7 @@ pytest
 pytest-cov
 pytest-randomly
 pytest-env
+ruff
 hypothesis
 hypothesis[numpy]
 scikit-learn
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
new file mode 100644
index 000000000..39cf99443
--- /dev/null
+++ b/utils/check_docstrings.py
@@ -0,0 +1,142 @@
+from __future__ import annotations
+
+import ast
+import doctest
+import subprocess
+import sys
+import tempfile
+from pathlib import Path
+
+
+def extract_docstring_examples(files: list[Path]) -> list[tuple[Path, str, str]]:
+    """Extract doctest example code from docstrings in Python files.
+
+    Returns one `(file, definition name, example source)` tuple for every
+    function (sync or async) or class whose docstring contains examples.
+    """
+    examples: list[tuple[Path, str, str]] = []
+
+    for file in files:
+        source = Path(file).read_text(encoding="utf-8")
+        tree = ast.parse(source, filename=str(file))
+
+        for node in ast.walk(tree):
+            if not isinstance(
+                node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
+            ):
+                continue
+            docstring = ast.get_docstring(node)
+            if not docstring:
+                continue
+            parsed_examples = doctest.DocTestParser().get_examples(docstring)
+            example_code = "\n".join(example.source for example in parsed_examples)
+            if example_code.strip():
+                examples.append((file, node.name, example_code))
+
+    return examples
+
+
+def create_temp_files(examples: list[tuple[Path, str, str]]) -> list[tuple[Path, str]]:
+    """Write each example to its own temporary file.
+
+    Returns `(temp file path, "file:name" label)` pairs; the caller is
+    responsible for deleting the files (see `cleanup_temp_files`).
+    """
+    temp_files: list[tuple[Path, str]] = []
+
+    for file, name, example in examples:
+        # `delete=False` keeps the file after the `with` block so ruff can read it.
+        # Examples are read as UTF-8, so write them back as UTF-8 too rather
+        # than relying on the platform's default encoding.
+        with tempfile.NamedTemporaryFile(
+            mode="w", suffix=".py", delete=False, encoding="utf-8"
+        ) as temp_file:
+            temp_file.write(example)
+        temp_files.append((Path(temp_file.name), f"{file}:{name}"))
+
+    return temp_files
+
+
+def run_ruff_on_temp_files(temp_files: list[tuple[Path, str]]) -> list[str]:
+    """Run ruff on all temporary files and collect its output lines.
+
+    Returns an empty list when ruff exits successfully. If ruff exits with a
+    non-zero status without writing to stdout (for example because it is not
+    installed), stderr is returned instead so the failure is not silent.
+    """
+    temp_file_paths = [str(temp_file) for temp_file, _ in temp_files]
+
+    result = subprocess.run(  # noqa: S603
+        [
+            # Use the running interpreter so ruff is resolved from the same env.
+            sys.executable,
+            "-m",
+            "ruff",
+            "check",
+            "--select=F",
+            "--ignore=F811",
+            *temp_file_paths,
+        ],
+        capture_output=True,
+        text=True,
+        check=False,
+    )
+
+    if result.returncode == 0:
+        return []  # No issues found
+    output = result.stdout.splitlines()
+    if output:
+        return output  # Ruff errors as a list of lines
+    return result.stderr.splitlines() or [
+        f"ruff exited with status {result.returncode}"
+    ]
+
+
+def report_errors(errors: list[str], temp_files: list[tuple[Path, str]]) -> None:
+    """Map errors back to original examples and report them."""
+    if not errors:
+        return
+
+    print("❌ Ruff issues found in examples:\n")  # noqa: T201
+    attributed = False
+    for line in errors:
+        for temp_file, original_context in temp_files:
+            if str(temp_file) in line:
+                print(f"{original_context}{line.replace(str(temp_file), '')}")  # noqa: T201
+                attributed = True
+                break
+
+    if not attributed:
+        # Nothing mentions an example file, so ruff itself failed: show raw output.
+        for line in errors:
+            print(line)  # noqa: T201
+
+
+def cleanup_temp_files(temp_files: list[tuple[Path, str]]) -> None:
+    """Remove all temporary files."""
+    for temp_file, _ in temp_files:
+        temp_file.unlink(missing_ok=True)
+
+
+def main(python_files: list[str]) -> None:
+    """Check the docstring examples of `python_files`; exit non-zero on errors."""
+    docstring_examples = extract_docstring_examples(
+        [Path(file) for file in python_files]
+    )
+
+    if not docstring_examples:
+        sys.exit(0)
+
+    temp_files = create_temp_files(docstring_examples)
+
+    try:
+        errors = run_ruff_on_temp_files(temp_files)
+        report_errors(errors, temp_files)
+    finally:
+        cleanup_temp_files(temp_files)
+
+    sys.exit(1 if errors else 0)
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])