Skip to content

Commit

Permalink
make Retriever an attribute on LearnChatHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Nov 7, 2023
1 parent 831c42b commit 1a63949
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 20 deletions.
45 changes: 28 additions & 17 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, Awaitable, Coroutine, List, Optional, Tuple

from dask.distributed import Client as DaskClient
from jupyter_ai.config_manager import ConfigManager
from jupyter_ai.document_loaders.directory import get_embeddings, split
from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
from jupyter_ai.models import (
Expand All @@ -31,6 +30,8 @@


class LearnChatHandler(BaseChatHandler):
_retriever: BaseRetriever

def __init__(
self, root_dir: str, dask_client_future: Awaitable[DaskClient], *args, **kwargs
):
Expand Down Expand Up @@ -61,8 +62,26 @@ def __init__(
if not os.path.exists(INDEX_SAVE_DIR):
os.makedirs(INDEX_SAVE_DIR)

self._init_retriever()
self._load()

def _init_retriever(self):
class Retriever(BaseRetriever):
def _get_relevant_documents(_) -> List[Document]:
raise NotImplementedError()

async def _aget_relevant_documents(
_, query: str
) -> Coroutine[Any, Any, List[Document]]:
# here `self` resolves to the LearnChatHandler parent.
return await self.aget_relevant_documents(query)

self._retriever = Retriever()

@property
def retriever(self):
return self._retriever

def _load(self):
"""Loads the vector store."""
embeddings = self.get_embedding_model()
Expand Down Expand Up @@ -272,16 +291,6 @@ def load_metadata(self):
j = json.loads(f.read())
self.metadata = IndexMetadata(**j)

async def aget_relevant_documents(
self, query: str
) -> Coroutine[Any, Any, List[Document]]:
if not self.index:
return []

await self.delete_and_relearn()
docs = self.index.similarity_search(query)
return docs

def get_embedding_provider(self):
return self.config_manager.em_provider, self.config_manager.em_provider_params

Expand All @@ -292,15 +301,17 @@ def get_embedding_model(self):

return em_provider_cls(**em_provider_args)


class Retriever(BaseRetriever):
learn_chat_handler: LearnChatHandler = None

def _get_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError()

async def _aget_relevant_documents(
async def aget_relevant_documents(
self, query: str
) -> Coroutine[Any, Any, List[Document]]:
docs = await self.learn_chat_handler.aget_relevant_documents(query)
"""This method defines the behavior of `self.retriever`, the LangChain
retriever object used by the AskChatHandler."""
if not self.index:
return []

await self.delete_and_relearn()
docs = self.index.similarity_search(query)
return docs
6 changes: 3 additions & 3 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import time

from dask.distributed import Client as DaskClient
from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp
from traitlets import List, Unicode
Expand Down Expand Up @@ -124,8 +123,9 @@ def initialize_settings(self):
dask_client_future=dask_client_future,
)
help_chat_handler = HelpChatHandler(**chat_handler_kwargs)
retriever = Retriever(learn_chat_handler=learn_chat_handler)
ask_chat_handler = AskChatHandler(**chat_handler_kwargs, retriever=retriever)
ask_chat_handler = AskChatHandler(
**chat_handler_kwargs, retriever=learn_chat_handler.retriever
)
self.settings["jai_chat_handlers"] = {
"default": default_chat_handler,
"/ask": ask_chat_handler,
Expand Down

0 comments on commit 1a63949

Please sign in to comment.