def arxiv_to_text(id: str, output_dir: str) -> str:
    """Download an arXiv paper's source and merge its TeX files into one file.

    The source tarball is downloaded into a private temporary directory and
    every regular ``.tex`` member is read straight out of the archive, in
    archive order, and appended to a single output file. Nothing is extracted
    to disk, so member names inside the archive cannot write outside the
    intended location, and the temporary download is removed even when an
    error occurs.

    Parameters
    ----------
    id : str
        id for the paper, numbers after "arXiv" in arXiv:xxxx.xxxxx

    output_dir : str
        directory to save the output file

    Returns
    -------
    output: str
        output path to the combined TeX file, named
        ``<id>-<YYYY-MM-DD-HH-MM>.tex`` inside ``output_dir`` (any ``/`` in
        the id is replaced by ``_``). The file is empty when the archive
        contains no ``.tex`` members.

    Raises
    ------
    ModuleNotFoundError
        If the ``arxiv`` package is not installed.
    ValueError
        If the search returns no paper for ``id``.
    tarfile.ReadError
        If the downloaded source is not a readable tar archive.
    """
    # Imported at call time so that a missing `arxiv` package raises
    # ModuleNotFoundError here rather than when this module is imported.
    import io
    import tempfile

    import arxiv

    # Old-style identifiers (e.g. "hep-th/9901001") contain a path separator
    # that must not end up inside the output file name.
    safe_id = id.replace("/", "_")
    outfile = f"{safe_id}-{datetime.now():%Y-%m-%d-%H-%M}.tex"
    download_filename = "downloaded-paper.tar.gz"
    output_path = os.path.join(output_dir, outfile)

    # Use a default so an empty result set becomes a clear error instead of a
    # bare StopIteration escaping from this function.
    paper = next(arxiv.Client().results(arxiv.Search(id_list=[id])), None)
    if paper is None:
        raise ValueError(f"No arXiv paper found for id {id!r}.")

    # A per-call temporary directory keeps the download out of the current
    # working directory, avoids clashes between concurrent calls, and is
    # deleted automatically on both success and failure.
    with tempfile.TemporaryDirectory() as tmp_dir:
        paper.download_source(dirpath=tmp_dir, filename=download_filename)
        archive_path = os.path.join(tmp_dir, download_filename)

        with tarfile.open(archive_path) as tar, open(
            output_path, "w", encoding="utf-8"
        ) as w:
            for member in tar:
                if not member.isfile():
                    continue
                if not member.name.lower().endswith(".tex"):
                    continue
                raw = tar.extractfile(member)
                if raw is None:
                    continue
                # TeX sources are not guaranteed to be valid UTF-8; replace
                # undecodable bytes rather than aborting the whole paper.
                with io.TextIOWrapper(
                    raw, encoding="utf-8", errors="replace"
                ) as tex:
                    w.write(tex.read())

    return output_path