diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 38390a44c..21c5ae6bb 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -4,7 +4,7 @@ from typing import Any, Coroutine, List, Optional, Tuple from dask.distributed import Client as DaskClient -from jupyter_ai.document_loaders.directory import get_embeddings, split +from jupyter_ai.document_loaders.directory import get_embeddings, split, arxiv_to_text from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter from jupyter_ai.models import ( DEFAULT_CHUNK_OVERLAP, @@ -44,6 +44,9 @@ def __init__(self, *args, **kwargs): self.parser.add_argument("-v", "--verbose", action="store_true") self.parser.add_argument("-d", "--delete", action="store_true") self.parser.add_argument("-l", "--list", action="store_true") + self.parser.add_argument( + "-r", "--remote", action="store" , default=None, type=str + ) self.parser.add_argument( "-c", "--chunk-size", action="store", default=DEFAULT_CHUNK_SIZE, type=int ) @@ -107,6 +110,17 @@ async def process_message(self, message: HumanChatMessage): self.reply(self._build_list_response()) return + if args.remote: + remote_type = args.remote.lower() + if remote_type=="arxiv": + try: + id = args.path[0] + args.path = [arxiv_to_text(id)] # call the function in `directory.py`` + self.reply(f"Processing arxiv file id {id}, saved in {args.path[0]}.", message) + except Exception as e: + self.reply(f"""The arXiv file could not be processed. Check the paper ID ({id}). Or, verify that the `arxiv` package is installed.""") + return + # Make sure the path exists. 
def arxiv_to_text(id):  # noqa: A002 - name kept so existing callers still work
    """Download an arXiv paper's LaTeX source and merge it into one ``.tex`` file.

    The source archive is downloaded and unpacked in a private temporary
    directory. Every top-level ``*.tex`` file in it is concatenated (in sorted
    filename order) into ``<id>.tex`` in the current working directory. The
    temporary directory is always removed, even when a step fails.

    Args:
        id: The arXiv identifier, i.e. the part after ``arXiv:`` in
            ``arXiv:xxxx.xxxxx``. Old-style identifiers containing ``/``
            are accepted; the slash is replaced by ``_`` in the output
            filename.

    Returns:
        The absolute path of the merged ``.tex`` file. The file is empty if
        the archive holds no top-level ``.tex`` file.

    Raises:
        ValueError: If arXiv returns no paper for ``id``.
        tarfile.TarError: If the downloaded source is not a readable tar
            archive, or a member is rejected by the extraction filter.
    """
    import tempfile  # local import: only this function needs it

    # Old-style ids look like "hep-th/9901001"; a raw "/" would point the
    # output into a directory that most likely does not exist.
    outfile_path = os.path.realpath(id.replace("/", "_") + ".tex")

    # A default of None avoids leaking a bare StopIteration to the caller
    # when the id is unknown.
    paper = next(arxiv.Client().results(arxiv.Search(id_list=[id])), None)
    if paper is None:
        raise ValueError(f"No arXiv paper found for id {id!r}")

    archive_name = "downloaded-paper.tar.gz"
    # A unique scratch directory cannot collide with (or delete) a user's own
    # folder, and the context manager removes it on every exit path.
    with tempfile.TemporaryDirectory(prefix="arxiv_") as temp_dir:
        paper.download_source(dirpath=temp_dir, filename=archive_name)

        with tarfile.open(os.path.join(temp_dir, archive_name)) as tar:
            # The archive is untrusted input: where the interpreter supports
            # extraction filters, refuse members that escape temp_dir.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(temp_dir, filter="data")
            else:
                # NOTE(review): no path-traversal protection on interpreters
                # without tarfile extraction filters.
                tar.extractall(temp_dir)

        # Sort so the concatenation order does not depend on the filesystem.
        tex_names = sorted(
            name for name in os.listdir(temp_dir) if name.lower().endswith(".tex")
        )
        with open(outfile_path, "wb") as merged:
            for name in tex_names:
                with open(os.path.join(temp_dir, name), "rb") as part:
                    shutil.copyfileobj(part, merged)

    return outfile_path
f5eb5e98f..68a770a1b 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -54,7 +54,7 @@ test = [ dev = ["jupyter_ai_magics[dev]"] -all = ["jupyter_ai_magics[all]", "pypdf"] +all = ["jupyter_ai_magics[all]", "pypdf", "arxiv"] [tool.hatch.version] source = "nodejs"