Run mypy on CI, fix or ignore typing issues #987

Merged 5 commits on Sep 12, 2024
23 changes: 21 additions & 2 deletions .github/workflows/unit-tests.yml
@@ -1,4 +1,4 @@
-name: Python Unit Tests
+name: Python Tests

# suppress warning raised by https://github.com/jupyter/jupyter_core/pull/292
env:
@@ -12,7 +12,7 @@ on:

jobs:
  unit-tests:
-    name: Linux
+    name: Unit tests
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
@@ -28,3 +28,22 @@ jobs:
        run: |
          set -eux
          pytest -vv -r ap --cov jupyter_ai
+
+  typing-tests:
+    name: Typing test
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Base Setup
+        uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1
+
+      - name: Install extension dependencies and build the extension
+        run: ./scripts/install.sh
+
+      - name: Run mypy
+        run: |
+          set -eux
+          mypy --version
+          mypy packages/jupyter-ai
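The new job fails the build on any typing regression; contributors can reproduce it locally by running `mypy packages/jupyter-ai` from the repository root. As a minimal sketch (hypothetical code, not from this diff) of the two remedies named in the PR title, fixing an error versus ignoring it with a scoped error code:

```python
from typing import Optional

# Fix: annotate the attribute so mypy does not infer the unusable type "None".
threshold: Optional[int] = None  # was `threshold = None`

# Ignore: acknowledge a deliberate mismatch with an error-code-scoped comment,
# so that unrelated [assignment] errors elsewhere still surface.
flexible: int = None  # type: ignore[assignment]
```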
@@ -43,6 +43,7 @@ class InlineCompletionItem(BaseModel):

class CompletionError(BaseModel):
    type: str
+    title: str
    traceback: str
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -72,6 +72,7 @@ async def process_message(self, message: HumanChatMessage):

        try:
            with self.pending("Searching learned documents", message):
+                assert self.llm_chain
                result = await self.llm_chain.acall({"question": query})
                response = result["answer"]
            self.reply(response, message)
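The bare `assert self.llm_chain` is the narrowing idiom used throughout this PR: `llm_chain` is declared `Optional[LLMChain]` on the base class, and the assert proves to mypy that it is non-None at the call site. A simplified sketch of the pattern, using hypothetical class names:

```python
from typing import Optional


class Chain:
    async def acall(self, inputs: dict) -> dict:
        return {"answer": "stub"}


class Handler:
    def __init__(self) -> None:
        self.llm_chain: Optional[Chain] = None

    async def process(self, query: str) -> dict:
        # Without the assert, mypy reports:
        #   Item "None" of "Optional[Chain]" has no attribute "acall"  [union-attr]
        assert self.llm_chain
        return await self.llm_chain.acall({"question": query})
```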
27 changes: 14 additions & 13 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -13,6 +13,7 @@
    Optional,
    Type,
    Union,
+    cast,
)
from uuid import uuid4

@@ -28,6 +29,7 @@
)
from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import BaseProvider
+from langchain.chains import LLMChain
from langchain.pydantic_v1 import BaseModel

if TYPE_CHECKING:
@@ -36,8 +38,8 @@
    from langchain_core.chat_history import BaseChatMessageHistory


-def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]:
-    if preferred_dir != "":
+def get_preferred_dir(root_dir: str, preferred_dir: Optional[str]) -> Optional[str]:
+    if preferred_dir is not None and preferred_dir != "":
        preferred_dir = os.path.expanduser(preferred_dir)
        if not preferred_dir.startswith(root_dir):
            preferred_dir = os.path.join(root_dir, preferred_dir)
@@ -47,7 +49,7 @@ def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]:

# Chat handler type, with specific attributes for each
class HandlerRoutingType(BaseModel):
-    routing_method: ClassVar[Union[Literal["slash_command"]]] = ...
+    routing_method: ClassVar[Union[Literal["slash_command"]]]
    """The routing method that sends commands to this handler."""

@@ -83,17 +85,17 @@ class BaseChatHandler:
    multiple chat handler classes."""

    # Class attributes
-    id: ClassVar[str] = ...
+    id: ClassVar[str]
    """ID for this chat handler; should be unique"""

-    name: ClassVar[str] = ...
+    name: ClassVar[str]
    """User-facing name of this handler"""

-    help: ClassVar[str] = ...
+    help: ClassVar[str]
    """What this chat handler does, which third-party models it contacts,
    the data it returns to the user, and so on, for display in the UI."""

-    routing_type: ClassVar[HandlerRoutingType] = ...
+    routing_type: ClassVar[HandlerRoutingType]

    uses_llm: ClassVar[bool] = True
    """Class attribute specifying whether this chat handler uses the LLM
@@ -153,9 +155,9 @@ def __init__(
        self.help_message_template = help_message_template
        self.chat_handlers = chat_handlers

-        self.llm = None
-        self.llm_params = None
-        self.llm_chain = None
+        self.llm: Optional[BaseProvider] = None
+        self.llm_params: Optional[dict] = None
+        self.llm_chain: Optional[LLMChain] = None

    async def on_message(self, message: HumanChatMessage):
        """
@@ -168,9 +170,8 @@ async def on_message(self, message: HumanChatMessage):

        # ensure the current slash command is supported
        if self.routing_type.routing_method == "slash_command":
-            slash_command = (
-                "/" + self.routing_type.slash_id if self.routing_type.slash_id else ""
-            )
+            routing_type = cast(SlashCommandRoutingType, self.routing_type)
+            slash_command = "/" + routing_type.slash_id if routing_type.slash_id else ""
            if slash_command in lm_provider_klass.unsupported_slash_commands:
                self.reply(
                    "Sorry, the selected language model does not support this slash command."
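Two of the base-class changes are worth a note. Assigning `= ...` gave each `ClassVar` the `Ellipsis` object as its default, which mypy rejects as an incompatible type; a bare annotation declares the attribute and leaves subclasses to supply the value. Separately, `cast()` narrows `routing_type` to its slash-command variant so `.slash_id` type-checks. A rough sketch under simplified names:

```python
from typing import ClassVar, Optional, cast


class RoutingType:
    routing_method: ClassVar[str]  # declared, not assigned


class SlashCommandRoutingType(RoutingType):
    routing_method: ClassVar[str] = "slash_command"
    slash_id: Optional[str] = None


class ChatHandler:
    id: ClassVar[str]  # subclasses must set this
    routing_type: ClassVar[RoutingType]

    def slash_command(self) -> str:
        # cast() does nothing at runtime; it only tells mypy which subtype
        # applies here, so `.slash_id` becomes a known attribute.
        routing = cast(SlashCommandRoutingType, self.routing_type)
        return "/" + routing.slash_id if routing.slash_id else ""
```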
7 changes: 4 additions & 3 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -41,10 +41,10 @@ def create_llm_chain(
        prompt_template = llm.get_chat_prompt_template()
        self.llm = llm

-        runnable = prompt_template | llm
+        runnable = prompt_template | llm  # type:ignore
        if not llm.manages_history:
            runnable = RunnableWithMessageHistory(
-                runnable=runnable,
+                runnable=runnable,  # type:ignore[arg-type]
                get_session_history=self.get_llm_chat_memory,
                input_messages_key="input",
                history_messages_key="history",
@@ -106,6 +106,7 @@ async def process_message(self, message: HumanChatMessage):
        # stream response in chunks. this works even if a provider does not
        # implement streaming, as `astream()` defaults to yielding `_call()`
        # when `_stream()` is not implemented on the LLM class.
+        assert self.llm_chain
        async for chunk in self.llm_chain.astream(
            {"input": message.body},
            config={"configurable": {"last_human_msg": message}},
@@ -117,7 +118,7 @@
                stream_id = self._start_stream(human_msg=message)
                received_first_chunk = True

-            if isinstance(chunk, AIMessageChunk):
+            if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
                self._send_stream_chunk(stream_id, chunk.content)
            elif isinstance(chunk, str):
                self._send_stream_chunk(stream_id, chunk)
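The added `isinstance(chunk.content, str)` guard reflects that chunk `content` in langchain-core is typed as a union of `str` and a list of content blocks, so a str-only sender needs narrowing on both the chunk type and its payload. A sketch of the consuming loop, with a stub sender:

```python
from langchain_core.messages import AIMessageChunk


def send_chunk(text: str) -> None:
    print(text, end="")


def forward(chunk: object) -> None:
    # `content` is typed Union[str, list]; forward only the str case.
    if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
        send_chunk(chunk.content)
    elif isinstance(chunk, str):
        send_chunk(chunk)
```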
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
@@ -93,6 +93,7 @@ async def process_message(self, message: HumanChatMessage):

        self.get_llm_chain()
        with self.pending("Analyzing error", message):
+            assert self.llm_chain
            response = await self.llm_chain.apredict(
                extra_instructions=extra_instructions,
                stop=["\nHuman:"],
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -226,7 +226,7 @@ class GenerateChatHandler(BaseChatHandler):
    def __init__(self, log_dir: Optional[str], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.log_dir = Path(log_dir) if log_dir else None
-        self.llm = None
+        self.llm: Optional[BaseProvider] = None

    def create_llm_chain(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
@@ -248,6 +248,7 @@ async def _generate_notebook(self, prompt: str):
        # Save the user input prompt, the description property is now LLM generated.
        outline["prompt"] = prompt

+        assert self.llm
        if self.llm.allows_concurrency:
            # fill the outline concurrently
            await afill_outline(outline, llm=self.llm, verbose=True)
14 changes: 9 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -223,7 +223,9 @@ async def learn_dir(
        }
        splitter = ExtensionSplitter(
            splitters=splitters,
-            default_splitter=RecursiveCharacterTextSplitter(**splitter_kwargs),
+            default_splitter=RecursiveCharacterTextSplitter(
+                **splitter_kwargs  # type:ignore[arg-type]
+            ),
        )

        delayed = split(path, all_files, splitter=splitter)
@@ -352,7 +354,7 @@ async def aget_relevant_documents(
        self, query: str
    ) -> Coroutine[Any, Any, List[Document]]:
        if not self.index:
-            return []
+            return []  # type:ignore[return-value]

        await self.delete_and_relearn()
        docs = self.index.similarity_search(query)
@@ -370,12 +372,14 @@ def get_embedding_model(self):


class Retriever(BaseRetriever):
-    learn_chat_handler: LearnChatHandler = None
+    learn_chat_handler: LearnChatHandler = None  # type:ignore[assignment]

-    def _get_relevant_documents(self, query: str) -> List[Document]:
+    def _get_relevant_documents(  # type:ignore[override]
+        self, query: str
+    ) -> List[Document]:
        raise NotImplementedError()

-    async def _aget_relevant_documents(
+    async def _aget_relevant_documents(  # type:ignore[override]
        self, query: str
    ) -> Coroutine[Any, Any, List[Document]]:
        docs = await self.learn_chat_handler.aget_relevant_documents(query)
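The `[override]` ignores on `Retriever` acknowledge that these signatures intentionally drop parameters declared by `BaseRetriever` (such as the run-manager argument), which mypy reports as an incompatible override. A minimal sketch of that situation, with hypothetical names:

```python
from typing import List


class Base:
    def get_documents(self, query: str, *, run_manager: object = None) -> List[str]:
        return []


class Child(Base):
    # mypy: Signature of "get_documents" incompatible with supertype "Base"
    def get_documents(self, query: str) -> List[str]:  # type: ignore[override]
        return [query]
```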
8 changes: 4 additions & 4 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
@@ -3,7 +3,7 @@
import os
import shutil
import time
-from typing import List, Optional, Union
+from typing import List, Optional, Type, Union

from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
@@ -60,7 +60,7 @@ class BlockedModelError(Exception):
    pass


-def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
+def _validate_provider_authn(config: GlobalConfig, provider: Type[AnyProvider]):
    # TODO: handle non-env auth strategies
    if not provider.auth_strategy or provider.auth_strategy.type != "env":
        return
@@ -147,7 +147,7 @@ def _init_config_schema(self):
            os.makedirs(os.path.dirname(self.schema_path), exist_ok=True)
            shutil.copy(OUR_SCHEMA_PATH, self.schema_path)

-    def _init_validator(self) -> Validator:
+    def _init_validator(self) -> None:
        with open(OUR_SCHEMA_PATH, encoding="utf-8") as f:
            schema = json.loads(f.read())
            Validator.check_schema(schema)
@@ -364,7 +364,7 @@ def delete_api_key(self, key_name: str):
        config_dict["api_keys"].pop(key_name, None)
        self._write_config(GlobalConfig(**config_dict))

-    def update_config(self, config_update: UpdateConfigRequest):
+    def update_config(self, config_update: UpdateConfigRequest):  # type:ignore
        last_write = os.stat(self.config_path).st_mtime_ns
        if config_update.last_read and config_update.last_read < last_write:
            raise WriteConflictError(
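`_validate_provider_authn` receives a provider class, not an instance, so the parameter becomes `Type[AnyProvider]`. A short sketch of the class-versus-instance distinction mypy enforces, using a hypothetical provider:

```python
from typing import Optional, Type


class Provider:
    auth_strategy: Optional[str] = "env"


def validate(provider: Type[Provider]) -> None:
    # Reads a class attribute directly off the class object.
    print(provider.auth_strategy)


validate(Provider)      # OK: the class itself is passed
# validate(Provider())  # mypy error: expected "Type[Provider]", got "Provider"
```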
@@ -30,7 +30,7 @@ def arxiv_to_text(id: str, output_dir: str) -> str:
    output path to the downloaded TeX file
    """

-    import arxiv
+    import arxiv  # type:ignore[import-not-found,import-untyped]

    outfile = f"{id}-{datetime.now():%Y-%m-%d-%H-%M}.tex"
    download_filename = "downloaded-paper.tar.gz"
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/extension.py
@@ -52,7 +52,7 @@

class AiExtension(ExtensionApp):
    name = "jupyter_ai"
-    handlers = [
+    handlers = [  # type:ignore[assignment]
        (r"api/ai/api_keys/(?P<api_key_name>\w+)", ApiKeysHandler),
        (r"api/ai/config/?", GlobalConfigHandler),
        (r"api/ai/chats/?", RootChatHandler),
15 changes: 7 additions & 8 deletions packages/jupyter-ai/jupyter_ai/handlers.py
@@ -4,7 +4,7 @@
import uuid
from asyncio import AbstractEventLoop
from dataclasses import asdict
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, cast

import tornado
from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType
@@ -42,14 +42,12 @@
    from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
    from jupyter_ai_magics.providers import BaseProvider

-    from .history import BoundChatHistory
+    from .history import BoundedChatHistory


class ChatHistoryHandler(BaseAPIHandler):
    """Handler to return message history"""

-    _messages = []
-
    @property
    def chat_history(self) -> List[ChatMessage]:
        return self.settings["chat_history"]
@@ -103,7 +101,7 @@ def chat_history(self, new_history):
        self.settings["chat_history"] = new_history

    @property
-    def llm_chat_memory(self) -> "BoundChatHistory":
+    def llm_chat_memory(self) -> "BoundedChatHistory":
        return self.settings["llm_chat_memory"]

    @property
@@ -145,6 +143,7 @@ def get_chat_user(self) -> ChatUser:
        environment."""
        # Get a dictionary of all loaded extensions.
        # (`serverapp` is a property on all `JupyterHandler` subclasses)
+        assert self.serverapp
        extensions = self.serverapp.extension_manager.extensions
        collaborative = (
            "jupyter_collaboration" in extensions
@@ -401,7 +400,7 @@ def filter_predicate(local_model_id: str):
            if self.blocked_models:
                return model_id not in self.blocked_models
            else:
-                return model_id in self.allowed_models
+                return model_id in cast(List, self.allowed_models)

        # filter out every model w/ model ID according to allow/blocklist
        for provider in providers:
@@ -514,7 +513,7 @@ def post(self):

class ApiKeysHandler(BaseAPIHandler):
    @property
-    def config_manager(self) -> ConfigManager:
+    def config_manager(self) -> ConfigManager:  # type:ignore[override]
        return self.settings["jai_config_manager"]

    @web.authenticated
@@ -529,7 +528,7 @@ class SlashCommandsInfoHandler(BaseAPIHandler):
    """List slash commands that are currently available to the user."""

    @property
-    def config_manager(self) -> ConfigManager:
+    def config_manager(self) -> ConfigManager:  # type:ignore[override]
        return self.settings["jai_config_manager"]

    @property
4 changes: 2 additions & 2 deletions packages/jupyter-ai/jupyter_ai/history.py
@@ -25,7 +25,7 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
    _all_messages: List[BaseMessage] = PrivateAttr(default_factory=list)

    @property
-    def messages(self) -> List[BaseMessage]:
+    def messages(self) -> List[BaseMessage]:  # type:ignore[override]
        if self.k is None:
            return self._all_messages
        return self._all_messages[-self.k * 2 :]
@@ -92,7 +92,7 @@ class WrappedBoundedChatHistory(BaseChatMessageHistory, BaseModel):
    last_human_msg: HumanChatMessage

    @property
-    def messages(self) -> List[BaseMessage]:
+    def messages(self) -> List[BaseMessage]:  # type:ignore[override]
        return self.history.messages

    def add_message(self, message: BaseMessage) -> None:
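Both ignores cover the same case: the base chat-history class declares `messages` as a plain attribute, and re-declaring it as a read-only property is an incompatible override for mypy even though it works at runtime. A compact sketch:

```python
from typing import List


class Base:
    messages: List[str] = []


class Bounded(Base):
    def __init__(self, k: int) -> None:
        self._all: List[str] = []
        self.k = k

    @property
    def messages(self) -> List[str]:  # type: ignore[override]
        # Expose only the last k entries, as BoundedChatHistory does.
        return self._all[-self.k :]
```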