diff --git a/conftest.py b/conftest.py index dc3b0e05c..099bc3f8e 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest pytest_plugins = ("jupyter_server.pytest_plugin",) @@ -6,3 +8,22 @@ @pytest.fixture def jp_server_config(jp_server_config): return {"ServerApp": {"jpserver_extensions": {"jupyter_ai": True}}} + + +@pytest.fixture(scope="session") +def static_test_files_dir() -> Path: + return ( + Path(__file__).parent.resolve() + / "packages" + / "jupyter-ai" + / "jupyter_ai" + / "tests" + / "static" + ) + + +@pytest.fixture +def jp_ai_staging_dir(jp_data_dir: Path) -> Path: + staging_area = jp_data_dir / "scheduler_staging_area" + staging_area.mkdir() + return staging_area diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py index e493fb385..7b9b28328 100644 --- a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py +++ b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py @@ -109,9 +109,13 @@ def flatten(*chunk_lists): return list(itertools.chain(*chunk_lists)) -def split(path, all_files: bool, splitter): - chunks = [] - +def collect_filepaths(path, all_files: bool): + """Selects eligible files, i.e., + 1. Files not in excluded directories, and + 2. Files that are in the valid file extensions list + Called from the `split` function. + Returns all the filepaths to eligible files. + """ # Check if the path points to a single file if os.path.isfile(path): filepaths = [Path(path)] @@ -125,17 +129,20 @@ def split(path, all_files: bool, splitter): d for d in subdirs if not (d[0] == "." 
or d in EXCLUDE_DIRS) ] filenames = [f for f in filenames if not f[0] == "."] - filepaths += [Path(os.path.join(dir, filename)) for filename in filenames] + filepaths.extend([Path(dir) / filename for filename in filenames]) + valid_exts = {j.lower() for j in SUPPORTED_EXTS} + filepaths = [fp for fp in filepaths if fp.suffix.lower() in valid_exts] + return filepaths - for filepath in filepaths: - # Lower case everything to make sure file extension comparisons are not case sensitive - if filepath.suffix.lower() not in {j.lower() for j in SUPPORTED_EXTS}: - continue +def split(path, all_files: bool, splitter): + """Splits files into chunks for vector db in RAG""" + chunks = [] + filepaths = collect_filepaths(path, all_files) + for filepath in filepaths: document = dask.delayed(path_to_doc)(filepath) chunk = dask.delayed(split_document)(document, splitter) chunks.append(chunk) - flattened_chunks = dask.delayed(flatten)(*chunks) return flattened_chunks diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.pdf b/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.pdf new file mode 100644 index 000000000..91cc2dec1 Binary files /dev/null and b/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.pdf differ diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.txt b/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.txt new file mode 100644 index 000000000..760cc808c --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.txt @@ -0,0 +1 @@ +Hidden temp text file. diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file0.html b/packages/jupyter-ai/jupyter_ai/tests/static/file0.html new file mode 100644 index 000000000..355180f0a --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/static/file0.html @@ -0,0 +1,10 @@ + + + + +Notebook + + +
This is the notebook content
+ + diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file1.txt b/packages/jupyter-ai/jupyter_ai/tests/static/file1.txt new file mode 100644 index 000000000..008cb927a --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/static/file1.txt @@ -0,0 +1 @@ +This is a temp test text file. diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file2.py b/packages/jupyter-ai/jupyter_ai/tests/static/file2.py new file mode 100644 index 000000000..746e17199 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/static/file2.py @@ -0,0 +1,3 @@ +import os + +print("Hello World") diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file3.csv b/packages/jupyter-ai/jupyter_ai/tests/static/file3.csv new file mode 100644 index 000000000..01fba4acc --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/static/file3.csv @@ -0,0 +1,2 @@ +Column1, Column2 +Test1, test2 diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file3.xyz b/packages/jupyter-ai/jupyter_ai/tests/static/file3.xyz new file mode 100644 index 000000000..e69de29bb diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file4.pdf b/packages/jupyter-ai/jupyter_ai/tests/static/file4.pdf new file mode 100644 index 000000000..1a35f789d Binary files /dev/null and b/packages/jupyter-ai/jupyter_ai/tests/static/file4.pdf differ diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_directory.py b/packages/jupyter-ai/jupyter_ai/tests/test_directory.py new file mode 100644 index 000000000..f9432b90d --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/test_directory.py @@ -0,0 +1,56 @@ +import os +import shutil +from pathlib import Path +from typing import Tuple + +import pytest +from jupyter_ai.document_loaders.directory import collect_filepaths + + +@pytest.fixture +def staging_dir(static_test_files_dir, jp_ai_staging_dir) -> Path: + file1_path = static_test_files_dir / ".hidden_file.pdf" + file2_path = static_test_files_dir / ".hidden_file.txt" + file3_path = static_test_files_dir / "file0.html" 
+ file4_path = static_test_files_dir / "file1.txt" + file5_path = static_test_files_dir / "file2.py" + file6_path = static_test_files_dir / "file3.csv" + file7_path = static_test_files_dir / "file3.xyz" + file8_path = static_test_files_dir / "file4.pdf" + + job_staging_dir = jp_ai_staging_dir / "TestDir" + job_staging_dir.mkdir() + job_staging_subdir = job_staging_dir / "subdir" + job_staging_subdir.mkdir() + job_staging_hiddendir = job_staging_dir / ".hidden_dir" + job_staging_hiddendir.mkdir() + + shutil.copy2(file1_path, job_staging_dir) + shutil.copy2(file2_path, job_staging_subdir) + shutil.copy2(file3_path, job_staging_dir) + shutil.copy2(file4_path, job_staging_subdir) + shutil.copy2(file5_path, job_staging_subdir) + shutil.copy2(file6_path, job_staging_hiddendir) + shutil.copy2(file7_path, job_staging_subdir) + shutil.copy2(file8_path, job_staging_hiddendir) + + return job_staging_dir + + +def test_collect_filepaths(staging_dir): + """ + Test that the number of valid files for `/learn` is correct. + i.e., the `collect_filepaths` function only selects files that are + 1. Not in the excluded directories and + 2. Are in the valid file extensions list. + """ + all_files = False + staging_dir_filepath = staging_dir + # Call the function we want to test + result = collect_filepaths(staging_dir_filepath, all_files) + + assert len(result) == 3 # Test number of valid files + + filenames = [fp.name for fp in result] + assert "file0.html" in filenames # Check that valid file is included + assert "file3.xyz" not in filenames # Check that invalid file is excluded