From 5efb6f2da07bee2683bb5cd0dfb58c45eac56820 Mon Sep 17 00:00:00 2001
From: "David L. Qiu"
Date: Tue, 25 Jun 2024 15:55:24 -0700
Subject: [PATCH] pre-commit

---
 .gitignore                                 |  1 -
 .../jupyter_ai_magics/providers.py         |  7 ++--
 .../jupyter_ai_test/_version.py            |  2 +-
 .../jupyter_ai_test/test_llms.py           | 12 +++---
 .../jupyter_ai_test/test_providers.py      | 10 ++---
 .../jupyter_ai_test/test_slash_commands.py |  4 +-
 .../jupyter_ai/chat_handlers/base.py       |  2 +-
 .../jupyter_ai/chat_handlers/default.py    | 31 +++++++++-------
 packages/jupyter-ai/jupyter_ai/handlers.py | 37 ++++++++++++-------
 packages/jupyter-ai/jupyter_ai/history.py  |  6 ++-
 packages/jupyter-ai/jupyter_ai/models.py   |  7 +++-
 11 files changed, 67 insertions(+), 52 deletions(-)

diff --git a/.gitignore b/.gitignore
index 8ea61ccdc..79f88f5fd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -134,4 +134,3 @@
 dev.sh
 .yarn
 .conda/
-
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index bef7094e0..8c8a1f917 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -41,9 +41,8 @@
     Together,
 )
 from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
-
-from langchain_core.language_models.llms import BaseLLM
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.llms import BaseLLM
 
 # this is necessary because `langchain.pydantic_v1.main` does not include
 # `ModelMetaclass`, as it is not listed in `__all__` by the `pydantic.main`
@@ -456,14 +455,14 @@ def _supports_sync_streaming(self):
             return not (self.__class__._stream is BaseChatModel._stream)
         else:
             return not (self.__class__._stream is BaseLLM._stream)
-    
+
     @property
     def _supports_async_streaming(self):
         if self.is_chat_provider:
             return not (self.__class__._astream is BaseChatModel._astream)
         else:
             return not (self.__class__._astream is BaseLLM._astream)
-    
+
     @property
     def supports_streaming(self):
         return self._supports_sync_streaming or self._supports_async_streaming
diff --git a/packages/jupyter-ai-test/jupyter_ai_test/_version.py b/packages/jupyter-ai-test/jupyter_ai_test/_version.py
index 133868ac5..1013248da 100644
--- a/packages/jupyter-ai-test/jupyter_ai_test/_version.py
+++ b/packages/jupyter-ai-test/jupyter_ai_test/_version.py
@@ -1,4 +1,4 @@
 # This file is auto-generated by Hatchling. As such, do not:
 #   - modify
 #   - track in version control e.g. be sure to add to .gitignore
-__version__ = VERSION = '0.1.0'
+__version__ = VERSION = "0.1.0"
diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
index 2e328eb2c..c7c72666b 100644
--- a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
+++ b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
@@ -1,10 +1,11 @@
 import time
-from typing import Any, List, Optional, Iterator
+from typing import Any, Iterator, List, Optional
 
 from langchain_core.callbacks.manager import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs.generation import GenerationChunk
 
+
 class TestLLM(LLM):
     model_id: str = "test"
 
@@ -21,7 +22,8 @@ def _call(
     ) -> str:
         time.sleep(3)
         return f"Hello! This is a dummy response from a test LLM."
-    
+
+
 class TestLLMWithStreaming(LLM):
     model_id: str = "test"
 
@@ -47,9 +49,9 @@ def _stream(
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         time.sleep(5)
-        yield GenerationChunk(text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n")
+        yield GenerationChunk(
+            text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n"
+        )
         for i in range(1, 101):
             time.sleep(0.5)
             yield GenerationChunk(text=f"{i}, ")
-
-        
\ No newline at end of file
diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py b/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py
index edbdff2ad..f2803deec 100644
--- a/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py
+++ b/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py
@@ -1,4 +1,5 @@
 from typing import ClassVar, List
+
 from jupyter_ai import AuthStrategy, BaseProvider, Field
 
 from .test_llms import TestLLM, TestLLMWithStreaming
@@ -11,9 +12,7 @@ class TestProvider(BaseProvider, TestLLM):
     name: ClassVar[str] = "Test Provider"
     """User-facing name of this provider."""
 
-    models: ClassVar[List[str]] = [
-        "test"
-    ]
+    models: ClassVar[List[str]] = ["test"]
     """List of supported models by their IDs. For registry providers, this will
     be just ["*"]."""
 
@@ -49,9 +48,7 @@ class TestProviderWithStreaming(BaseProvider, TestLLMWithStreaming):
     name: ClassVar[str] = "Test Provider (streaming)"
     """User-facing name of this provider."""
 
-    models: ClassVar[List[str]] = [
-        "test"
-    ]
+    models: ClassVar[List[str]] = ["test"]
    """List of supported models by their IDs. For registry providers, this will
     be just ["*"]."""
 
@@ -78,4 +75,3 @@ class TestProviderWithStreaming(BaseProvider, TestLLMWithStreaming):
     fields: ClassVar[List[Field]] = []
     """User inputs expected by this provider when initializing it. Each `Field` `f`
     should be passed in the constructor as a keyword argument, keyed by `f.key`."""
-
diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py b/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py
index 876d21372..f82bd5531 100644
--- a/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py
+++ b/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py
@@ -1,18 +1,20 @@
 from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
 from jupyter_ai.models import HumanChatMessage
 
+
 class TestSlashCommand(BaseChatHandler):
     """
     A test slash command implementation that developers should build from. The
     string used to invoke this command is set by the `slash_id` keyword argument
     in the `routing_type` attribute. The command is mainly implemented in the
     `process_message()` method. See built-in implementations under
-    `jupyter_ai/handlers` for further reference. 
+    `jupyter_ai/handlers` for further reference.
 
     The provider is made available to Jupyter AI by the entry point declared in
     `pyproject.toml`. If this class or parent module is renamed, make sure to
     update the entry point there as well.
     """
+
     id = "test"
     name = "Test"
     help = "A test slash command."
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index 6bee9ec20..fd412eed7 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -197,7 +197,7 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
                 handler.broadcast_message(agent_msg)
                 break
 
-    
+
     @property
     def persona(self):
         return self.config_manager.persona
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 055df1afb..624b43699 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,15 +1,19 @@
+import time
 from typing import Dict, Type
 from uuid import uuid4
-import time
 
-from jupyter_ai.models import HumanChatMessage, AgentStreamMessage, AgentStreamChunkMessage
+from jupyter_ai.models import (
+    AgentStreamChunkMessage,
+    AgentStreamMessage,
+    HumanChatMessage,
+)
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.memory import ConversationBufferWindowMemory
-from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.messages import AIMessageChunk
+from langchain_core.runnables.history import RunnableWithMessageHistory
 
-from .base import BaseChatHandler, SlashCommandRoutingType
 from ..history import BoundedChatHistory
+from .base import BaseChatHandler, SlashCommandRoutingType
 
 
 class DefaultChatHandler(BaseChatHandler):
@@ -48,9 +52,8 @@ def create_llm_chain(
             input_messages_key="input",
             history_messages_key="history",
         )
-        
-        self.llm_chain = runnable
 
+        self.llm_chain = runnable
 
     def _start_stream(self, human_msg: HumanChatMessage) -> str:
         """
@@ -64,7 +67,7 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str:
             body="",
             reply_to=human_msg.id,
             persona=self.persona,
-            complete=False
+            complete=False,
         )
 
         for handler in self._root_chat_handlers.values():
@@ -75,16 +78,14 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str:
             break
 
         return stream_id
-    
+
     def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False):
         """
         Sends an `agent-stream-chunk` message containing content that should be
         appended to an existing `agent-stream` message with ID `stream_id`.
         """
         stream_chunk_msg = AgentStreamChunkMessage(
-            id=stream_id,
-            content=content,
-            stream_complete=complete
+            id=stream_id, content=content, stream_complete=complete
         )
 
         for handler in self._root_chat_handlers.values():
@@ -93,7 +94,6 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = Fals
                 handler.broadcast_message(stream_chunk_msg)
                 break
 
-
     async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
 
@@ -105,7 +105,10 @@ async def process_message(self, message: HumanChatMessage):
         # stream response in chunks. this works even if a provider does not
         # implement streaming, as `astream()` defaults to yielding `_call()`
         # when `_stream()` is not implemented on the LLM class.
-        async for chunk in self.llm_chain.astream({ "input": message.body }, config={"configurable": {"session_id": "static_session"}}):
+        async for chunk in self.llm_chain.astream(
+            {"input": message.body},
+            config={"configurable": {"session_id": "static_session"}},
+        ):
             if not received_first_chunk:
                 # when receiving the first chunk, close the pending message and
                 # start the stream.
@@ -120,6 +123,6 @@ async def process_message(self, message: HumanChatMessage):
             else:
                 self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
                 break
-    
+
         # complete stream after all chunks have been streamed
         self._send_stream_chunk(stream_id, "", complete=True)
diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py
index 6ba3560f5..1ea2edee0 100644
--- a/packages/jupyter-ai/jupyter_ai/handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/handlers.py
@@ -17,13 +17,14 @@
 
 from .models import (
     AgentChatMessage,
-    AgentStreamMessage,
     AgentStreamChunkMessage,
+    AgentStreamMessage,
     ChatClient,
     ChatHistory,
     ChatMessage,
     ChatRequest,
     ChatUser,
+    ClosePendingMessage,
     ConnectionMessage,
     HumanChatMessage,
     ListProvidersEntry,
@@ -32,7 +33,6 @@
     ListSlashCommandsResponse,
     Message,
     PendingMessage,
-    ClosePendingMessage,
     UpdateConfigRequest,
 )
 
@@ -49,7 +49,7 @@ class ChatHistoryHandler(BaseAPIHandler):
     @property
     def chat_history(self) -> List[ChatMessage]:
         return self.settings["chat_history"]
-    
+
     @property
     def pending_messages(self) -> List[PendingMessage]:
         return self.settings["pending_messages"]
@@ -57,8 +57,7 @@ def pending_messages(self) -> List[PendingMessage]:
     @tornado.web.authenticated
     async def get(self):
         history = ChatHistory(
-            messages=self.chat_history,
-            pending_messages=self.pending_messages
+            messages=self.chat_history, pending_messages=self.pending_messages
         )
         self.finish(history.json())
 
@@ -106,7 +105,7 @@ def loop(self) -> AbstractEventLoop:
     @property
     def pending_messages(self) -> List[PendingMessage]:
         return self.settings["pending_messages"]
-    
+
     @pending_messages.setter
     def pending_messages(self, new_pending_messages):
         self.settings["pending_messages"] = new_pending_messages
@@ -186,10 +185,14 @@ def open(self):
         self.root_chat_handlers[client_id] = self
         self.chat_clients[client_id] = ChatClient(**current_user, id=client_id)
         self.client_id = client_id
-        self.write_message(ConnectionMessage(
-            client_id=client_id,
-            history=ChatHistory(messages=self.chat_history, pending_messages=self.pending_messages)
-        ).dict())
+        self.write_message(
+            ConnectionMessage(
+                client_id=client_id,
+                history=ChatHistory(
+                    messages=self.chat_history, pending_messages=self.pending_messages
+                ),
+            ).dict()
+        )
 
         self.log.info(f"Client connected. ID: {client_id}")
ID: {client_id}") self.log.debug("Clients are : %s", self.root_chat_handlers.keys()) @@ -208,7 +211,9 @@ def broadcast_message(self, message: Message): client.write_message(message.dict()) # append all messages of type `ChatMessage` directly to the chat history - if isinstance(message, (HumanChatMessage, AgentChatMessage, AgentStreamMessage)): + if isinstance( + message, (HumanChatMessage, AgentChatMessage, AgentStreamMessage) + ): self.chat_history.append(message) elif isinstance(message, AgentStreamChunkMessage): # for stream chunks, modify the corresponding `AgentStreamMessage` @@ -217,7 +222,10 @@ def broadcast_message(self, message: Message): # iterate backwards from the end of the list for i in range(len(self.chat_history) - 1, -1, -1): - if self.chat_history[i].type == 'agent-stream' and self.chat_history[i].id == chunk.id: + if ( + self.chat_history[i].type == "agent-stream" + and self.chat_history[i].id == chunk.id + ): stream_message: AgentStreamMessage = self.chat_history[i] stream_message.body += chunk.content stream_message.complete = chunk.stream_complete @@ -225,8 +233,9 @@ def broadcast_message(self, message: Message): elif isinstance(message, PendingMessage): self.pending_messages.append(message) elif isinstance(message, ClosePendingMessage): - self.pending_messages = list(filter(lambda m: m.id != message.id, self.pending_messages)) - + self.pending_messages = list( + filter(lambda m: m.id != message.id, self.pending_messages) + ) async def on_message(self, message): self.log.debug("Message received: %s", message) diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index 42f1cbf1e..02b77b911 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -1,7 +1,9 @@ -from langchain_core.messages import BaseMessage +from typing import List, Sequence + from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage from langchain_core.pydantic_v1 import BaseModel, Field -from typing import List, Sequence + class BoundedChatHistory(BaseChatMessageHistory, BaseModel): """ diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index d59451d76..ea541568b 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -68,13 +68,15 @@ class AgentChatMessage(BaseModel): this defaults to a description of `JupyternautPersona`. """ + class AgentStreamMessage(AgentChatMessage): - type: Literal['agent-stream'] = 'agent-stream' + type: Literal["agent-stream"] = "agent-stream" complete: bool # other attrs inherited from `AgentChatMessage` + class AgentStreamChunkMessage(BaseModel): - type: Literal['agent-stream-chunk'] = 'agent-stream-chunk' + type: Literal["agent-stream-chunk"] = "agent-stream-chunk" id: str content: str stream_complete: bool @@ -118,6 +120,7 @@ class ClosePendingMessage(BaseModel): class ChatHistory(BaseModel): """History of chat messages""" + messages: List[ChatMessage] pending_messages: List[PendingMessage]