Skip to content

Commit

Permalink
Refactor split function with tests (jupyterlab#811)
Browse files Browse the repository at this point in the history
* Refactor split function with test

The split function was (1) selecting files in included directories in the top half of the function, and (2) selecting files with valid extensions and sharding them in the second half. This PR divides the split function in a new `collect_files` function that selects files with valid extensions from non-excluded directories, and then passes the valid filepaths into the `split` function, which calls `collect_files`.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Test changed to use pytest

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor split function with test

The split function was (1) selecting files in included directories in the top half of the function, and (2) selecting files with valid extensions and sharding them in the second half. This PR divides the split function in a new `collect_files` function that selects files with valid extensions from non-excluded directories, and then passes the valid filepaths into the `split` function, which calls `collect_files`.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Test changed to use pytest

* refactored tests for directory.py using pytest fixtures

Replaced testing using unittests with testing using pytest fixtures.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove old test files

replacd unittests with pytests

* Update test_directory.py

* Update docstrings and further improve code for retrieve filepaths and split

Further improvements to the code suggested from the review of PR

* update docstring in test file

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update directory.py

Changed function level constant from all caps to lower case to line up with the convention in https://peps.python.org/pep-0008/#constants.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
srdas and pre-commit-ci[bot] authored Jun 5, 2024
1 parent ffffb15 commit f138b0d
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 9 deletions.
21 changes: 21 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import pytest

pytest_plugins = ("jupyter_server.pytest_plugin",)
Expand All @@ -6,3 +8,22 @@
@pytest.fixture
def jp_server_config(jp_server_config):
return {"ServerApp": {"jpserver_extensions": {"jupyter_ai": True}}}


@pytest.fixture(scope="session")
def static_test_files_dir() -> Path:
return (
Path(__file__).parent.resolve()
/ "packages"
/ "jupyter-ai"
/ "jupyter_ai"
/ "tests"
/ "static"
)


@pytest.fixture
def jp_ai_staging_dir(jp_data_dir: Path) -> Path:
staging_area = jp_data_dir / "scheduler_staging_area"
staging_area.mkdir()
return staging_area
25 changes: 16 additions & 9 deletions packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,13 @@ def flatten(*chunk_lists):
return list(itertools.chain(*chunk_lists))


def split(path, all_files: bool, splitter):
chunks = []

def collect_filepaths(path, all_files: bool):
"""Selects eligible files, i.e.,
1. Files not in excluded directories, and
2. Files that are in the valid file extensions list
Called from the `split` function.
Returns all the filepaths to eligible files.
"""
# Check if the path points to a single file
if os.path.isfile(path):
filepaths = [Path(path)]
Expand All @@ -125,17 +129,20 @@ def split(path, all_files: bool, splitter):
d for d in subdirs if not (d[0] == "." or d in EXCLUDE_DIRS)
]
filenames = [f for f in filenames if not f[0] == "."]
filepaths += [Path(os.path.join(dir, filename)) for filename in filenames]
filepaths.extend([Path(dir) / filename for filename in filenames])
valid_exts = {j.lower() for j in SUPPORTED_EXTS}
filepaths = [fp for fp in filepaths if fp.suffix.lower() in valid_exts]
return filepaths

for filepath in filepaths:
# Lower case everything to make sure file extension comparisons are not case sensitive
if filepath.suffix.lower() not in {j.lower() for j in SUPPORTED_EXTS}:
continue

def split(path, all_files: bool, splitter):
"""Splits files into chunks for vector db in RAG"""
chunks = []
filepaths = collect_filepaths(path, all_files)
for filepath in filepaths:
document = dask.delayed(path_to_doc)(filepath)
chunk = dask.delayed(split_document)(document, splitter)
chunks.append(chunk)

flattened_chunks = dask.delayed(flatten)(*chunks)
return flattened_chunks

Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Hidden temp text file.
10 changes: 10 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/static/file0.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
<!DOCTYPE html>
<html>
<head><meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Notebook</title>
</head>
<body>
<div>This is the notebook content</div>
</body>
</html>
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/tests/static/file1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This is a temp test text file.
3 changes: 3 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/static/file2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os

print("Hello World")
2 changes: 2 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/static/file3.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Column1, Column2
Test1, test2
Empty file.
Binary file not shown.
56 changes: 56 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_directory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import shutil
from pathlib import Path
from typing import Tuple

import pytest
from jupyter_ai.document_loaders.directory import collect_filepaths


@pytest.fixture
def staging_dir(static_test_files_dir, jp_ai_staging_dir) -> Path:
file1_path = static_test_files_dir / ".hidden_file.pdf"
file2_path = static_test_files_dir / ".hidden_file.txt"
file3_path = static_test_files_dir / "file0.html"
file4_path = static_test_files_dir / "file1.txt"
file5_path = static_test_files_dir / "file2.py"
file6_path = static_test_files_dir / "file3.csv"
file7_path = static_test_files_dir / "file3.xyz"
file8_path = static_test_files_dir / "file4.pdf"

job_staging_dir = jp_ai_staging_dir / "TestDir"
job_staging_dir.mkdir()
job_staging_subdir = job_staging_dir / "subdir"
job_staging_subdir.mkdir()
job_staging_hiddendir = job_staging_dir / ".hidden_dir"
job_staging_hiddendir.mkdir()

shutil.copy2(file1_path, job_staging_dir)
shutil.copy2(file2_path, job_staging_subdir)
shutil.copy2(file3_path, job_staging_dir)
shutil.copy2(file4_path, job_staging_subdir)
shutil.copy2(file5_path, job_staging_subdir)
shutil.copy2(file6_path, job_staging_hiddendir)
shutil.copy2(file7_path, job_staging_subdir)
shutil.copy2(file8_path, job_staging_hiddendir)

return job_staging_dir


def test_collect_filepaths(staging_dir):
"""
Test that the number of valid files for `/learn` is correct.
i.e., the `collect_filepaths` function only selects files that are
1. Not in the the excluded directories and
2. Are in the valid file extensions list.
"""
all_files = False
staging_dir_filepath = staging_dir
# Call the function we want to test
result = collect_filepaths(staging_dir_filepath, all_files)

assert len(result) == 3 # Test number of valid files

filenames = [fp.name for fp in result]
assert "file0.html" in filenames # Check that valid file is included
assert "file3.xyz" not in filenames # Check that invalid file is excluded

0 comments on commit f138b0d

Please sign in to comment.