diff --git a/conftest.py b/conftest.py
index dc3b0e05c..099bc3f8e 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
import pytest
pytest_plugins = ("jupyter_server.pytest_plugin",)
@@ -6,3 +8,22 @@
@pytest.fixture
def jp_server_config(jp_server_config):
return {"ServerApp": {"jpserver_extensions": {"jupyter_ai": True}}}
+
+
+@pytest.fixture(scope="session")
+def static_test_files_dir() -> Path:
+ return (
+ Path(__file__).parent.resolve()
+ / "packages"
+ / "jupyter-ai"
+ / "jupyter_ai"
+ / "tests"
+ / "static"
+ )
+
+
+@pytest.fixture
+def jp_ai_staging_dir(jp_data_dir: Path) -> Path:
+ staging_area = jp_data_dir / "scheduler_staging_area"
+ staging_area.mkdir()
+ return staging_area
diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
index e493fb385..7b9b28328 100644
--- a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
+++ b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
@@ -109,9 +109,13 @@ def flatten(*chunk_lists):
return list(itertools.chain(*chunk_lists))
-def split(path, all_files: bool, splitter):
- chunks = []
-
+def collect_filepaths(path, all_files: bool):
+ """Selects eligible files, i.e.,
+ 1. Files not in excluded directories, and
+ 2. Files that are in the valid file extensions list
+ Called from the `split` function.
+ Returns all the filepaths to eligible files.
+ """
# Check if the path points to a single file
if os.path.isfile(path):
filepaths = [Path(path)]
@@ -125,17 +129,20 @@ def split(path, all_files: bool, splitter):
d for d in subdirs if not (d[0] == "." or d in EXCLUDE_DIRS)
]
filenames = [f for f in filenames if not f[0] == "."]
- filepaths += [Path(os.path.join(dir, filename)) for filename in filenames]
+ filepaths.extend([Path(dir) / filename for filename in filenames])
+ valid_exts = {j.lower() for j in SUPPORTED_EXTS}
+ filepaths = [fp for fp in filepaths if fp.suffix.lower() in valid_exts]
+ return filepaths
- for filepath in filepaths:
- # Lower case everything to make sure file extension comparisons are not case sensitive
- if filepath.suffix.lower() not in {j.lower() for j in SUPPORTED_EXTS}:
- continue
+def split(path, all_files: bool, splitter):
+ """Splits files into chunks for vector db in RAG"""
+ chunks = []
+ filepaths = collect_filepaths(path, all_files)
+ for filepath in filepaths:
document = dask.delayed(path_to_doc)(filepath)
chunk = dask.delayed(split_document)(document, splitter)
chunks.append(chunk)
-
flattened_chunks = dask.delayed(flatten)(*chunks)
return flattened_chunks
diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.pdf b/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.pdf
new file mode 100644
index 000000000..91cc2dec1
Binary files /dev/null and b/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.pdf differ
diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.txt b/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.txt
new file mode 100644
index 000000000..760cc808c
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/tests/static/.hidden_file.txt
@@ -0,0 +1 @@
+Hidden temp text file.
diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file0.html b/packages/jupyter-ai/jupyter_ai/tests/static/file0.html
new file mode 100644
index 000000000..355180f0a
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/tests/static/file0.html
@@ -0,0 +1,10 @@
+
+
+
+
+Notebook
+
+
+ This is the notebook content
+
+
diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file1.txt b/packages/jupyter-ai/jupyter_ai/tests/static/file1.txt
new file mode 100644
index 000000000..008cb927a
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/tests/static/file1.txt
@@ -0,0 +1 @@
+This is a temp test text file.
diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file2.py b/packages/jupyter-ai/jupyter_ai/tests/static/file2.py
new file mode 100644
index 000000000..746e17199
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/tests/static/file2.py
@@ -0,0 +1,3 @@
+import os
+
+print("Hello World")
diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file3.csv b/packages/jupyter-ai/jupyter_ai/tests/static/file3.csv
new file mode 100644
index 000000000..01fba4acc
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/tests/static/file3.csv
@@ -0,0 +1,2 @@
+Column1, Column2
+Test1, test2
diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file3.xyz b/packages/jupyter-ai/jupyter_ai/tests/static/file3.xyz
new file mode 100644
index 000000000..e69de29bb
diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file4.pdf b/packages/jupyter-ai/jupyter_ai/tests/static/file4.pdf
new file mode 100644
index 000000000..1a35f789d
Binary files /dev/null and b/packages/jupyter-ai/jupyter_ai/tests/static/file4.pdf differ
diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_directory.py b/packages/jupyter-ai/jupyter_ai/tests/test_directory.py
new file mode 100644
index 000000000..f9432b90d
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/tests/test_directory.py
@@ -0,0 +1,56 @@
+import os
+import shutil
+from pathlib import Path
+from typing import Tuple
+
+import pytest
+from jupyter_ai.document_loaders.directory import collect_filepaths
+
+
+@pytest.fixture
+def staging_dir(static_test_files_dir, jp_ai_staging_dir) -> Path:
+ file1_path = static_test_files_dir / ".hidden_file.pdf"
+ file2_path = static_test_files_dir / ".hidden_file.txt"
+ file3_path = static_test_files_dir / "file0.html"
+ file4_path = static_test_files_dir / "file1.txt"
+ file5_path = static_test_files_dir / "file2.py"
+ file6_path = static_test_files_dir / "file3.csv"
+ file7_path = static_test_files_dir / "file3.xyz"
+ file8_path = static_test_files_dir / "file4.pdf"
+
+ job_staging_dir = jp_ai_staging_dir / "TestDir"
+ job_staging_dir.mkdir()
+ job_staging_subdir = job_staging_dir / "subdir"
+ job_staging_subdir.mkdir()
+ job_staging_hiddendir = job_staging_dir / ".hidden_dir"
+ job_staging_hiddendir.mkdir()
+
+ shutil.copy2(file1_path, job_staging_dir)
+ shutil.copy2(file2_path, job_staging_subdir)
+ shutil.copy2(file3_path, job_staging_dir)
+ shutil.copy2(file4_path, job_staging_subdir)
+ shutil.copy2(file5_path, job_staging_subdir)
+ shutil.copy2(file6_path, job_staging_hiddendir)
+ shutil.copy2(file7_path, job_staging_subdir)
+ shutil.copy2(file8_path, job_staging_hiddendir)
+
+ return job_staging_dir
+
+
+def test_collect_filepaths(staging_dir):
+ """
+ Test that the number of valid files for `/learn` is correct.
+ i.e., the `collect_filepaths` function only selects files that are
+ 1. Not in the the excluded directories and
+ 2. Are in the valid file extensions list.
+ """
+ all_files = False
+ staging_dir_filepath = staging_dir
+ # Call the function we want to test
+ result = collect_filepaths(staging_dir_filepath, all_files)
+
+ assert len(result) == 3 # Test number of valid files
+
+ filenames = [fp.name for fp in result]
+ assert "file0.html" in filenames # Check that valid file is included
+ assert "file3.xyz" not in filenames # Check that invalid file is excluded