diff --git a/docs/source/users/index.md b/docs/source/users/index.md
index 1556cc91c..2906fc074 100644
--- a/docs/source/users/index.md
+++ b/docs/source/users/index.md
@@ -464,6 +464,13 @@ use the `-a` or `--all-files` option.
 /learn -a
 ```
 
+### Learning arXiv files
+The `/learn` command can also download and process papers from the [arXiv](https://arxiv.org/) repository. This feature requires the `arxiv` Python package; install it with `pip install arxiv`.
+
+```
+/learn -r arxiv 2404.18558
+```
+
 ### Additional chat commands
 
 To clear the chat panel, use the `/clear` command. This does not reset the AI model; the model may still remember previous messages that you sent it, and it may use them to inform its responses.
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
index f9348a21f..0f10b0147 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -4,7 +4,7 @@
 from typing import Any, Coroutine, List, Optional, Tuple
 
 from dask.distributed import Client as DaskClient
-from jupyter_ai.document_loaders.directory import get_embeddings, split
+from jupyter_ai.document_loaders.directory import arxiv_to_text, get_embeddings, split
 from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
 from jupyter_ai.models import (
     DEFAULT_CHUNK_OVERLAP,
@@ -44,6 +44,9 @@ def __init__(self, *args, **kwargs):
         self.parser.add_argument("-v", "--verbose", action="store_true")
         self.parser.add_argument("-d", "--delete", action="store_true")
         self.parser.add_argument("-l", "--list", action="store_true")
+        self.parser.add_argument(
+            "-r", "--remote", action="store", default=None, type=str
+        )
         self.parser.add_argument(
             "-c", "--chunk-size", action="store", default=DEFAULT_CHUNK_SIZE, type=int
         )
@@ -110,6 +113,30 @@ async def process_message(self, message: HumanChatMessage):
             self.reply(self._build_list_response())
             return
 
+        if args.remote:
+            remote_type = args.remote.lower()
+            if remote_type == "arxiv":
+                try:
+                    id = args.path[0]
+                    args.path = [arxiv_to_text(id, self.root_dir)]
+                    self.reply(
+                        f"Learning arxiv file with id **{id}**, saved in **{args.path[0]}**.",
+                        message,
+                    )
+                except ModuleNotFoundError as e:
+                    self.log.error(e)
+                    self.reply(
+                        "No `arxiv` package found. " "Install with `pip install arxiv`."
+                    )
+                    return
+                except Exception as e:
+                    self.log.error(e)
+                    self.reply(
+                        "An error occurred while processing the arXiv file. "
+                        f"Please verify that the arxiv id {id} is correct."
+                    )
+                    return
+
         # Make sure the path exists.
         if not len(args.path) == 1:
             self.reply(f"{self.parser.format_usage()}", message)
diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
index b8f9de3bb..e493fb385 100644
--- a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
+++ b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
@@ -1,6 +1,8 @@
 import hashlib
 import itertools
 import os
+import tarfile
+from datetime import datetime
 from pathlib import Path
 from typing import List
 
@@ -10,6 +12,51 @@
 from langchain_community.document_loaders import PyPDFLoader
 
 
+def arxiv_to_text(id: str, output_dir: str) -> str:
+    """Downloads and extracts a single tar file from arXiv.
+    Combines the TeX components into a single file.
+
+    Parameters
+    ----------
+    id : str
+        id of the paper, the numbers after "arXiv" in arXiv:xxxx.xxxxx
+
+    output_dir : str
+        directory to save the output file
+
+    Returns
+    -------
+    output : str
+        path to the combined TeX file
+    """
+
+    import arxiv
+
+    outfile = f"{id}-{datetime.now():%Y-%m-%d-%H-%M}.tex"
+    download_filename = "downloaded-paper.tar.gz"
+    output_path = os.path.join(output_dir, outfile)
+
+    paper = next(arxiv.Client().results(arxiv.Search(id_list=[id])))
+    paper.download_source(filename=download_filename)
+
+    with tarfile.open(download_filename) as tar:
+        tex_list = []
+        for member in tar:
+            if member.isfile() and member.name.lower().endswith(".tex"):
+                tex_list.append(member.name)
+                tar.extract(member, path="")
+
+    with open(output_path, "w") as w:
+        for f in tex_list:
+            with open(f) as tex:
+                w.write(tex.read())
+            os.remove(f)
+
+    os.remove(download_filename)
+
+    return output_path
+
+
 # Uses pypdf which is used by PyPDFLoader from langchain
 def pdf_to_text(path):
     pages = PyPDFLoader(path)
@@ -50,6 +97,7 @@ def path_to_doc(path):
     ".txt",
     ".html",
     ".pdf",
+    ".tex",  # added for raw LaTeX files from arXiv
 }
 
 
diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml
index ec7959975..4328c9aca 100644
--- a/packages/jupyter-ai/pyproject.toml
+++ b/packages/jupyter-ai/pyproject.toml
@@ -53,7 +53,7 @@ test = [
 
 dev = ["jupyter_ai_magics[dev]"]
 
-all = ["jupyter_ai_magics[all]", "pypdf"]
+all = ["jupyter_ai_magics[all]", "pypdf", "arxiv"]
 
 [tool.hatch.version]
 source = "nodejs"
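As a rough sketch of how the pieces in this diff fit together outside the chat UI, the snippet below calls the new `arxiv_to_text` helper directly. This is illustrative only, not part of the change: it assumes the `arxiv` package is installed (`pip install arxiv`), that the current working directory is writable (the helper downloads the source tarball and extracts its `.tex` members there before writing the combined file into `output_dir`), and it reuses the example paper id `2404.18558` from the documentation hunk.

```python
# Minimal sketch (not part of the diff): exercise arxiv_to_text on its own.
import tempfile

from jupyter_ai.document_loaders.directory import arxiv_to_text

with tempfile.TemporaryDirectory() as output_dir:
    # Downloads the paper's source tarball, concatenates its .tex members,
    # and returns the path of the combined TeX file inside output_dir.
    tex_path = arxiv_to_text("2404.18558", output_dir)
    with open(tex_path) as f:
        print(f.read()[:500])  # preview the first 500 characters
```

In the chat flow, `/learn -r arxiv 2404.18558` does essentially the same thing: the handler rewrites `args.path` to point at the generated `.tex` file (saved under `self.root_dir`), and the existing learning logic then indexes it because `.tex` was added to the supported extensions.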