From f64bd7ebf53de022447df832c45c28abf43999ef Mon Sep 17 00:00:00 2001
From: krassowski <5832902+krassowski@users.noreply.github.com>
Date: Tue, 10 Sep 2024 15:07:41 +0100
Subject: [PATCH] Fix typing jupyter-ai codebase (mostly)

---
 .../jupyter_ai_magics/models/completion.py   |  1 +
 .../jupyter_ai_magics/py.typed               |  0
 .../jupyter_ai/chat_handlers/ask.py          |  1 +
 .../jupyter_ai/chat_handlers/base.py         | 27 ++++++++++---------
 .../jupyter_ai/chat_handlers/default.py      |  7 ++---
 .../jupyter_ai/chat_handlers/fix.py          |  1 +
 .../jupyter_ai/chat_handlers/generate.py     |  3 ++-
 .../jupyter_ai/chat_handlers/learn.py        | 14 ++++++----
 .../jupyter-ai/jupyter_ai/config_manager.py  | 10 ++++---
 .../jupyter_ai/document_loaders/directory.py |  2 +-
 packages/jupyter-ai/jupyter_ai/extension.py  |  2 +-
 packages/jupyter-ai/jupyter_ai/handlers.py   | 10 +++----
 packages/jupyter-ai/jupyter_ai/history.py    |  4 +--
 packages/jupyter-ai/jupyter_ai/models.py     | 18 ++++++-------
 14 files changed, 55 insertions(+), 45 deletions(-)
 create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/py.typed

diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py
index 147f6ceec..f2ee0cd54 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py
@@ -43,6 +43,7 @@ class InlineCompletionItem(BaseModel):
 
 class CompletionError(BaseModel):
     type: str
+    title: str
     traceback: str
 
 
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/py.typed b/packages/jupyter-ai-magics/jupyter_ai_magics/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
index 5c3026685..2da303e5a 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -72,6 +72,7 @@ async def process_message(self, message: HumanChatMessage):
 
         try:
             with self.pending("Searching learned documents", message):
+                assert self.llm_chain
                 result = await self.llm_chain.acall({"question": query})
                 response = result["answer"]
             self.reply(response, message)
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index b97015518..347bfbf83 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -13,6 +13,7 @@
     Optional,
     Type,
     Union,
+    cast,
 )
 from uuid import uuid4
 
@@ -28,6 +29,7 @@
 )
 from jupyter_ai_magics import Persona
 from jupyter_ai_magics.providers import BaseProvider
+from langchain.chains import LLMChain
 from langchain.pydantic_v1 import BaseModel
 
 if TYPE_CHECKING:
@@ -36,8 +38,8 @@
     from langchain_core.chat_history import BaseChatMessageHistory
 
 
-def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]:
-    if preferred_dir != "":
+def get_preferred_dir(root_dir: str, preferred_dir: Optional[str]) -> Optional[str]:
+    if preferred_dir is not None and preferred_dir != "":
         preferred_dir = os.path.expanduser(preferred_dir)
         if not preferred_dir.startswith(root_dir):
             preferred_dir = os.path.join(root_dir, preferred_dir)
@@ -47,7 +49,7 @@ def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]:
 
 # Chat handler type, with specific attributes for each
 class HandlerRoutingType(BaseModel):
-    routing_method: ClassVar[Union[Literal["slash_command"]]] = ...
+    routing_method: ClassVar[Union[Literal["slash_command"]]]
     """The routing method that sends commands to this handler."""
 
 
@@ -83,17 +85,17 @@ class BaseChatHandler:
     multiple chat handler classes."""
 
     # Class attributes
-    id: ClassVar[str] = ...
+    id: ClassVar[str]
    """ID for this chat handler; should be unique"""
 
-    name: ClassVar[str] = ...
+    name: ClassVar[str]
     """User-facing name of this handler"""
 
-    help: ClassVar[str] = ...
+    help: ClassVar[str]
     """What this chat handler does, which third-party models it contacts,
     the data it returns to the user, and so on, for display in the UI."""
 
-    routing_type: ClassVar[HandlerRoutingType] = ...
+    routing_type: ClassVar[HandlerRoutingType]
 
     uses_llm: ClassVar[bool] = True
     """Class attribute specifying whether this chat handler uses the LLM
@@ -153,9 +155,9 @@ def __init__(
         self.help_message_template = help_message_template
         self.chat_handlers = chat_handlers
 
-        self.llm = None
-        self.llm_params = None
-        self.llm_chain = None
+        self.llm: Optional[BaseProvider] = None
+        self.llm_params: Optional[dict] = None
+        self.llm_chain: Optional[LLMChain] = None
 
     async def on_message(self, message: HumanChatMessage):
         """
@@ -168,9 +170,8 @@ async def on_message(self, message: HumanChatMessage):
 
         # ensure the current slash command is supported
         if self.routing_type.routing_method == "slash_command":
-            slash_command = (
-                "/" + self.routing_type.slash_id if self.routing_type.slash_id else ""
-            )
+            routing_type = cast(SlashCommandRoutingType, self.routing_type)
+            slash_command = "/" + routing_type.slash_id if routing_type.slash_id else ""
             if slash_command in lm_provider_klass.unsupported_slash_commands:
                 self.reply(
                     "Sorry, the selected language model does not support this slash command."
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 3e8f194b3..2bf88bafe 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -41,10 +41,10 @@ def create_llm_chain(
         prompt_template = llm.get_chat_prompt_template()
         self.llm = llm
 
-        runnable = prompt_template | llm
+        runnable = prompt_template | llm  # type:ignore
         if not llm.manages_history:
             runnable = RunnableWithMessageHistory(
-                runnable=runnable,
+                runnable=runnable,  # type:ignore[arg-type]
                 get_session_history=self.get_llm_chat_memory,
                 input_messages_key="input",
                 history_messages_key="history",
@@ -106,6 +106,7 @@ async def process_message(self, message: HumanChatMessage):
             # stream response in chunks. this works even if a provider does not
             # implement streaming, as `astream()` defaults to yielding `_call()`
             # when `_stream()` is not implemented on the LLM class.
+            assert self.llm_chain
             async for chunk in self.llm_chain.astream(
                 {"input": message.body},
                 config={"configurable": {"last_human_msg": message}},
@@ -117,7 +118,7 @@ async def process_message(self, message: HumanChatMessage):
                     stream_id = self._start_stream(human_msg=message)
                     received_first_chunk = True
 
-                if isinstance(chunk, AIMessageChunk):
+                if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
                     self._send_stream_chunk(stream_id, chunk.content)
                 elif isinstance(chunk, str):
                     self._send_stream_chunk(stream_id, chunk)
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
index d6ecc6d81..1056e592c 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
@@ -93,6 +93,7 @@ async def process_message(self, message: HumanChatMessage):
 
         self.get_llm_chain()
         with self.pending("Analyzing error", message):
+            assert self.llm_chain
             response = await self.llm_chain.apredict(
                 extra_instructions=extra_instructions,
                 stop=["\nHuman:"],
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
index 52398eabe..a69b5ed28 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -226,7 +226,7 @@ class GenerateChatHandler(BaseChatHandler):
     def __init__(self, log_dir: Optional[str], *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.log_dir = Path(log_dir) if log_dir else None
-        self.llm = None
+        self.llm: Optional[BaseProvider] = None
 
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
@@ -248,6 +248,7 @@ async def _generate_notebook(self, prompt: str):
         # Save the user input prompt, the description property is now LLM generated.
outline["prompt"] = prompt + assert self.llm if self.llm.allows_concurrency: # fill the outline concurrently await afill_outline(outline, llm=self.llm, verbose=True) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 29e147f22..b846e28ea 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -223,7 +223,9 @@ async def learn_dir( } splitter = ExtensionSplitter( splitters=splitters, - default_splitter=RecursiveCharacterTextSplitter(**splitter_kwargs), + default_splitter=RecursiveCharacterTextSplitter( + **splitter_kwargs + ), # type:ignore[arg-type] ) delayed = split(path, all_files, splitter=splitter) @@ -352,7 +354,7 @@ async def aget_relevant_documents( self, query: str ) -> Coroutine[Any, Any, List[Document]]: if not self.index: - return [] + return [] # type:ignore[return-value] await self.delete_and_relearn() docs = self.index.similarity_search(query) @@ -370,12 +372,14 @@ def get_embedding_model(self): class Retriever(BaseRetriever): - learn_chat_handler: LearnChatHandler = None + learn_chat_handler: LearnChatHandler = None # type:ignore[assignment] - def _get_relevant_documents(self, query: str) -> List[Document]: + def _get_relevant_documents( + self, query: str + ) -> List[Document]: # type:ignore[override] raise NotImplementedError() - async def _aget_relevant_documents( + async def _aget_relevant_documents( # type:ignore[override] self, query: str ) -> Coroutine[Any, Any, List[Document]]: docs = await self.learn_chat_handler.aget_relevant_documents(query) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index d7674c5a2..d8072b655 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -3,7 +3,7 @@ import os import shutil import time -from typing import List, Optional, Union +from typing import List, Optional, Type, Union from deepmerge import always_merger as Merger from jsonschema import Draft202012Validator as Validator @@ -60,7 +60,7 @@ class BlockedModelError(Exception): pass -def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider): +def _validate_provider_authn(config: GlobalConfig, provider: Type[AnyProvider]): # TODO: handle non-env auth strategies if not provider.auth_strategy or provider.auth_strategy.type != "env": return @@ -147,7 +147,7 @@ def _init_config_schema(self): os.makedirs(os.path.dirname(self.schema_path), exist_ok=True) shutil.copy(OUR_SCHEMA_PATH, self.schema_path) - def _init_validator(self) -> Validator: + def _init_validator(self) -> None: with open(OUR_SCHEMA_PATH, encoding="utf-8") as f: schema = json.loads(f.read()) Validator.check_schema(schema) @@ -364,7 +364,9 @@ def delete_api_key(self, key_name: str): config_dict["api_keys"].pop(key_name, None) self._write_config(GlobalConfig(**config_dict)) - def update_config(self, config_update: UpdateConfigRequest): + def update_config( + self, config_update: UpdateConfigRequest + ): # type:ignore[override] last_write = os.stat(self.config_path).st_mtime_ns if config_update.last_read and config_update.last_read < last_write: raise WriteConflictError( diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py index 7b9b28328..130695283 100644 --- a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py +++ 
@@ -30,7 +30,7 @@ def arxiv_to_text(id: str, output_dir: str) -> str:
         output path to the downloaded TeX file
 
     """
-    import arxiv
+    import arxiv  # type:ignore[import-not-found]
 
     outfile = f"{id}-{datetime.now():%Y-%m-%d-%H-%M}.tex"
     download_filename = "downloaded-paper.tar.gz"
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index a57fb0eb9..b017da0df 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -52,7 +52,7 @@ class AiExtension(ExtensionApp):
     name = "jupyter_ai"
 
-    handlers = [
+    handlers = [  # type:ignore[assignment]
         (r"api/ai/api_keys/(?P<api_key_name>\w+)", ApiKeysHandler),
         (r"api/ai/config/?", GlobalConfigHandler),
         (r"api/ai/chats/?", RootChatHandler),
diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py
index a614e3e84..9de16f9eb 100644
--- a/packages/jupyter-ai/jupyter_ai/handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/handlers.py
@@ -4,7 +4,7 @@
 import uuid
 from asyncio import AbstractEventLoop
 from dataclasses import asdict
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, cast
 
 import tornado
 from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType
@@ -42,14 +42,12 @@
     from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
     from jupyter_ai_magics.providers import BaseProvider
 
-    from .history import BoundChatHistory
+    from .history import BoundedChatHistory
 
 
 class ChatHistoryHandler(BaseAPIHandler):
     """Handler to return message history"""
 
-    _messages = []
-
     @property
     def chat_history(self) -> List[ChatMessage]:
         return self.settings["chat_history"]
@@ -103,7 +101,7 @@ def chat_history(self, new_history):
         self.settings["chat_history"] = new_history
 
     @property
-    def llm_chat_memory(self) -> "BoundChatHistory":
+    def llm_chat_memory(self) -> "BoundedChatHistory":
         return self.settings["llm_chat_memory"]
 
     @property
@@ -401,7 +399,7 @@ def filter_predicate(local_model_id: str):
             if self.blocked_models:
                 return model_id not in self.blocked_models
             else:
-                return model_id in self.allowed_models
+                return model_id in cast(List, self.allowed_models)
 
         # filter out every model w/ model ID according to allow/blocklist
         for provider in providers:
diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py
index 9e1064194..2b2628e5d 100644
--- a/packages/jupyter-ai/jupyter_ai/history.py
+++ b/packages/jupyter-ai/jupyter_ai/history.py
@@ -25,7 +25,7 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
     _all_messages: List[BaseMessage] = PrivateAttr(default_factory=list)
 
     @property
-    def messages(self) -> List[BaseMessage]:
+    def messages(self) -> List[BaseMessage]:  # type:ignore[override]
         return self._all_messages[-self.k * 2 :]
 
     async def aget_messages(self) -> List[BaseMessage]:
@@ -90,7 +90,7 @@ class WrappedBoundedChatHistory(BaseChatMessageHistory, BaseModel):
     last_human_msg: HumanChatMessage
 
     @property
-    def messages(self) -> List[BaseMessage]:
+    def messages(self) -> List[BaseMessage]:  # type:ignore[override]
         return self.history.messages
 
     def add_message(self, message: BaseMessage) -> None:
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index f9098a12a..bda4d3421 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -70,8 +70,7 @@ class ChatClient(ChatUser):
     id: str
 
 
-class AgentChatMessage(BaseModel):
-    type: Literal["agent"] = "agent"
+class BaseAgentMessage(BaseModel):
     id: str
     time: float
     body: str
@@ -89,7 +88,11 @@ class AgentChatMessage(BaseModel):
     """
 
 
-class AgentStreamMessage(AgentChatMessage):
+class AgentChatMessage(BaseAgentMessage):
+    type: Literal["agent"] = "agent"
+
+
+class AgentStreamMessage(BaseAgentMessage):
     type: Literal["agent-stream"] = "agent-stream"
     complete: bool
     # other attrs inherited from `AgentChatMessage`
@@ -138,15 +141,13 @@ class PendingMessage(BaseModel):
 
 
 class ClosePendingMessage(BaseModel):
-    type: Literal["pending"] = "close-pending"
+    type: Literal["close-pending"] = "close-pending"
     id: str
 
 
 # the type of messages being broadcast to clients
 ChatMessage = Union[
-    AgentChatMessage,
-    HumanChatMessage,
-    AgentStreamMessage,
+    AgentChatMessage, HumanChatMessage, AgentStreamMessage, AgentStreamChunkMessage
 ]
 
 
@@ -164,8 +165,7 @@ class ConnectionMessage(BaseModel):
 
 
 Message = Union[
-    AgentChatMessage,
-    HumanChatMessage,
+    ChatMessage,
     ConnectionMessage,
     ClearMessage,
     PendingMessage,