diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py index bbe00777e..6137609b6 100644 --- a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py +++ b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py @@ -108,8 +108,9 @@ def split_document(document, splitter: TextSplitter) -> List[Document]: def flatten(*chunk_lists): return list(itertools.chain(*chunk_lists)) -# Selects eligible files, i.e., -# 1. Files not in excluded directories, and + +# Selects eligible files, i.e., +# 1. Files not in excluded directories, and # 2. Files that are in the valid file extensions list # Called from the `split` function. def collect_files(path, all_files: bool): @@ -127,7 +128,11 @@ def collect_files(path, all_files: bool): ] filenames = [f for f in filenames if not f[0] == "."] filepaths += [Path(os.path.join(dir, filename)) for filename in filenames] - filepaths = [fp for fp in filepaths if fp.suffix.lower() in {j.lower() for j in SUPPORTED_EXTS}] + filepaths = [ + fp + for fp in filepaths + if fp.suffix.lower() in {j.lower() for j in SUPPORTED_EXTS} + ] return filepaths diff --git a/packages/jupyter-ai/jupyter_ai/tests/doc_loaders/__init__.py b/packages/jupyter-ai/jupyter_ai/tests/doc_loaders/__init__.py index 7df157d78..5fbde7b8d 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/doc_loaders/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/tests/doc_loaders/__init__.py @@ -1 +1 @@ -"""Tests for the collect_files and split functions in directory.py.""" \ No newline at end of file +"""Tests for the collect_files and split functions in directory.py.""" diff --git a/packages/jupyter-ai/jupyter_ai/tests/doc_loaders/test_eligible_files.py b/packages/jupyter-ai/jupyter_ai/tests/doc_loaders/test_eligible_files.py index 075528c10..16627e48a 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/doc_loaders/test_eligible_files.py +++ b/packages/jupyter-ai/jupyter_ai/tests/doc_loaders/test_eligible_files.py @@ -1,30 +1,39 @@ """ -Test that the collect_files function only selects files that are -1. Not in the the excluded directories and +Test that the collect_files function only selects files that are +1. Not in the the excluded directories and 2. Are in the valid file extensions list. """ import os import unittest + from jupyter_ai.document_loaders.directory import collect_files + class TestCollectFiles(unittest.TestCase): # Prepare temp directory for test os.mkdir("TestDir") path = os.path.join(os.getcwd(), "TestDir") - test_dir_contents = {'0' : ['file0.html', '.hidden_file.pdf'], # top level folder, 1 valid file - 'subdir' : ['file1.txt','.hidden_file.txt','file2.py','file3.xyz'], # subfolder, 2 valid files - '.hidden_dir' : ['file3.csv', 'file4.pdf']} # hidden subfolder, no valid files + test_dir_contents = { + "0": ["file0.html", ".hidden_file.pdf"], # top level folder, 1 valid file + "subdir": [ + "file1.txt", + ".hidden_file.txt", + "file2.py", + "file3.xyz", + ], # subfolder, 2 valid files + ".hidden_dir": ["file3.csv", "file4.pdf"], + } # hidden subfolder, no valid files for folder in test_dir_contents: os.chdir(path) - if folder != '0': + if folder != "0": os.mkdir(folder) - d = os.path.join(path,folder) - else: + d = os.path.join(path, folder) + else: d = path for file in test_dir_contents[folder]: filepath = os.path.join(d, file) - open(filepath, 'a') + open(filepath, "a") # Test that the number of valid files for `/learn` is correct def test_collect_files(self): @@ -35,8 +44,10 @@ def test_collect_files(self): # Clean up temp directory from shutil import rmtree + rmtree(self.path) os.chdir(os.path.split(self.path)[0]) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main()