diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index b93685c0e..8ac6dd2cd 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -4,7 +4,6 @@ from typing import Any, Awaitable, Coroutine, List, Optional, Tuple from dask.distributed import Client as DaskClient -from jupyter_ai.config_manager import ConfigManager from jupyter_ai.document_loaders.directory import get_embeddings, split from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter from jupyter_ai.models import ( @@ -31,6 +30,8 @@ class LearnChatHandler(BaseChatHandler): + _retriever: BaseRetriever + def __init__( self, root_dir: str, dask_client_future: Awaitable[DaskClient], *args, **kwargs ): @@ -61,8 +62,26 @@ def __init__( if not os.path.exists(INDEX_SAVE_DIR): os.makedirs(INDEX_SAVE_DIR) + self._init_retriever() self._load() + def _init_retriever(self): + class Retriever(BaseRetriever): + def _get_relevant_documents(_) -> List[Document]: + raise NotImplementedError() + + async def _aget_relevant_documents( + _, query: str + ) -> Coroutine[Any, Any, List[Document]]: + # here `self` resolves to the LearnChatHandler parent. + return await self.aget_relevant_documents(query) + + self._retriever = Retriever() + + @property + def retriever(self): + return self._retriever + def _load(self): """Loads the vector store.""" embeddings = self.get_embedding_model() @@ -272,16 +291,6 @@ def load_metadata(self): j = json.loads(f.read()) self.metadata = IndexMetadata(**j) - async def aget_relevant_documents( - self, query: str - ) -> Coroutine[Any, Any, List[Document]]: - if not self.index: - return [] - - await self.delete_and_relearn() - docs = self.index.similarity_search(query) - return docs - def get_embedding_provider(self): return self.config_manager.em_provider, self.config_manager.em_provider_params @@ -292,15 +301,17 @@ def get_embedding_model(self): return em_provider_cls(**em_provider_args) - -class Retriever(BaseRetriever): - learn_chat_handler: LearnChatHandler = None - def _get_relevant_documents(self, query: str) -> List[Document]: raise NotImplementedError() - async def _aget_relevant_documents( + async def aget_relevant_documents( self, query: str ) -> Coroutine[Any, Any, List[Document]]: - docs = await self.learn_chat_handler.aget_relevant_documents(query) + """This method defines the behavior of `self.retriever`, the LangChain + retriever object used by the AskChatHandler.""" + if not self.index: + return [] + + await self.delete_and_relearn() + docs = self.index.similarity_search(query) return docs diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 50865ed96..59c04a0a4 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,7 +1,6 @@ import time from dask.distributed import Client as DaskClient -from jupyter_ai.chat_handlers.learn import Retriever from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp from traitlets import List, Unicode @@ -124,8 +123,9 @@ def initialize_settings(self): dask_client_future=dask_client_future, ) help_chat_handler = HelpChatHandler(**chat_handler_kwargs) - retriever = Retriever(learn_chat_handler=learn_chat_handler) - ask_chat_handler = AskChatHandler(**chat_handler_kwargs, retriever=retriever) + ask_chat_handler = AskChatHandler( + **chat_handler_kwargs, retriever=learn_chat_handler.retriever + ) self.settings["jai_chat_handlers"] = { "default": default_chat_handler, "/ask": ask_chat_handler,