diff --git a/.gitignore b/.gitignore
index 7bcadb5c8..0f1f752b4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -135,5 +135,6 @@ dev.sh
 .conda/
 
-# reserved for testing cookiecutter
-packages/jupyter-ai-test
+# Version files are auto-generated by Hatchling and should not be committed to
+# the source repo.
+packages/**/_version.py
diff --git a/packages/jupyter-ai-magics/.gitignore b/packages/jupyter-ai-magics/.gitignore
index 236c7b23c..77ab48f38 100644
--- a/packages/jupyter-ai-magics/.gitignore
+++ b/packages/jupyter-ai-magics/.gitignore
@@ -8,8 +8,6 @@ node_modules/
 .ipynb_checkpoints
 *.tsbuildinfo
 jupyter_ai_magics/labextension
-# Version file is handled by hatchling
-jupyter_ai_magics/_version.py
 
 # Integration tests
 ui-tests/test-results/
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index 0dd02b343..8c8a1f917 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -19,7 +19,6 @@
 )
 
 from jsonpath_ng import parse
-from langchain.chat_models.base import BaseChatModel
 from langchain.prompts import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
@@ -42,6 +41,8 @@
     Together,
 )
 from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.llms import BaseLLM
 
 # this is necessary because `langchain.pydantic_v1.main` does not include
 # `ModelMetaclass`, as it is not listed in `__all__` by the `pydantic.main`
@@ -448,6 +449,24 @@ def is_chat_provider(self):
     def allows_concurrency(self):
         return True
 
+    @property
+    def _supports_sync_streaming(self):
+        if self.is_chat_provider:
+            return not (self.__class__._stream is BaseChatModel._stream)
+        else:
+            return not (self.__class__._stream is BaseLLM._stream)
+
+    @property
+    def _supports_async_streaming(self):
+        if self.is_chat_provider:
+            return not (self.__class__._astream is BaseChatModel._astream)
+        else:
+            return not (self.__class__._astream is BaseLLM._astream)
+
+    @property
+    def supports_streaming(self):
+        return self._supports_sync_streaming or self._supports_async_streaming
+
     async def generate_inline_completions(
         self, request: InlineCompletionRequest
     ) -> InlineCompletionReply:
diff --git a/packages/jupyter-ai-test/LICENSE b/packages/jupyter-ai-test/LICENSE
new file mode 100644
index 000000000..eb0d24e83
--- /dev/null
+++ b/packages/jupyter-ai-test/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2024, Project Jupyter
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/packages/jupyter-ai-test/README.md b/packages/jupyter-ai-test/README.md
new file mode 100644
index 000000000..1d0b31613
--- /dev/null
+++ b/packages/jupyter-ai-test/README.md
@@ -0,0 +1,58 @@
+# jupyter_ai_test
+
+`jupyter_ai_test` is a Jupyter AI module that registers additional model
+providers and slash commands for testing Jupyter AI in a local development
+environment. This package should never be published on NPM or PyPI.
+
+## Requirements
+
+- Python 3.8 - 3.11
+- JupyterLab 4
+
+## Install
+
+To install the extension, execute:
+
+```bash
+pip install jupyter_ai_test
+```
+
+## Uninstall
+
+To remove the extension, execute:
+
+```bash
+pip uninstall jupyter_ai_test
+```
+
+## Contributing
+
+### Development install
+
+```bash
+cd jupyter-ai-test
+pip install -e "."
+```
+
+### Development uninstall
+
+```bash
+pip uninstall jupyter_ai_test
+```
+
+#### Backend tests
+
+This package uses [Pytest](https://docs.pytest.org/) for Python testing.
+
+Install test dependencies (needed only once):
+
+```sh
+cd jupyter-ai-test
+pip install -e ".[test]"
+```
+
+To execute them, run:
+
+```sh
+pytest -vv -r ap --cov jupyter_ai_test
+```
diff --git a/packages/jupyter-ai-test/jupyter_ai_test/__init__.py b/packages/jupyter-ai-test/jupyter_ai_test/__init__.py
new file mode 100644
index 000000000..8dee4bf82
--- /dev/null
+++ b/packages/jupyter-ai-test/jupyter_ai_test/__init__.py
@@ -0,0 +1 @@
+from ._version import __version__
diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
new file mode 100644
index 000000000..c7c72666b
--- /dev/null
+++ b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
@@ -0,0 +1,57 @@
+import time
+from typing import Any, Iterator, List, Optional
+
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs.generation import GenerationChunk
+
+
+class TestLLM(LLM):
+    model_id: str = "test"
+
+    @property
+    def _llm_type(self) -> str:
+        return "custom"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        time.sleep(3)
+        return "Hello! This is a dummy response from a test LLM."
+
+
+class TestLLMWithStreaming(LLM):
+    model_id: str = "test"
+
+    @property
+    def _llm_type(self) -> str:
+        return "custom"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        time.sleep(3)
+        return "Hello! This is a dummy response from a test LLM."
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        time.sleep(5)
+        yield GenerationChunk(
+            text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n"
+        )
+        for i in range(1, 101):
+            time.sleep(0.5)
+            yield GenerationChunk(text=f"{i}, ")
diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py b/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py
new file mode 100644
index 000000000..f2803deec
--- /dev/null
+++ b/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py
@@ -0,0 +1,77 @@
+from typing import ClassVar, List
+
+from jupyter_ai import AuthStrategy, BaseProvider, Field
+
+from .test_llms import TestLLM, TestLLMWithStreaming
+
+
+class TestProvider(BaseProvider, TestLLM):
+    id: ClassVar[str] = "test-provider"
+    """ID for this provider class."""
+
+    name: ClassVar[str] = "Test Provider"
+    """User-facing name of this provider."""
+
+    models: ClassVar[List[str]] = ["test"]
+    """List of supported models by their IDs. For registry providers, this will
+    be just ["*"]."""
+
+    help: ClassVar[str] = None
+    """Text to display in lieu of a model list for a registry provider that does
+    not provide a list of models."""
+
+    model_id_key: ClassVar[str] = "model_id"
+    """Kwarg expected by the upstream LangChain provider."""
+
+    model_id_label: ClassVar[str] = "Model ID"
+    """Human-readable label of the model ID."""
+
+    pypi_package_deps: ClassVar[List[str]] = []
+    """List of PyPI package dependencies."""
+
+    auth_strategy: ClassVar[AuthStrategy] = None
+    """Authentication/authorization strategy. Declares what credentials are
+    required to use this model provider. Generally should not be `None`."""
+
+    registry: ClassVar[bool] = False
+    """Whether this provider is a registry provider."""
+
+    fields: ClassVar[List[Field]] = []
+    """User inputs expected by this provider when initializing it. Each `Field` `f`
+    should be passed in the constructor as a keyword argument, keyed by `f.key`."""
+
+
+class TestProviderWithStreaming(BaseProvider, TestLLMWithStreaming):
+    id: ClassVar[str] = "test-provider-with-streaming"
+    """ID for this provider class."""
+
+    name: ClassVar[str] = "Test Provider (streaming)"
+    """User-facing name of this provider."""
+
+    models: ClassVar[List[str]] = ["test"]
+    """List of supported models by their IDs. For registry providers, this will
+    be just ["*"]."""
+
+    help: ClassVar[str] = None
+    """Text to display in lieu of a model list for a registry provider that does
+    not provide a list of models."""
+
+    model_id_key: ClassVar[str] = "model_id"
+    """Kwarg expected by the upstream LangChain provider."""
+
+    model_id_label: ClassVar[str] = "Model ID"
+    """Human-readable label of the model ID."""
+
+    pypi_package_deps: ClassVar[List[str]] = []
+    """List of PyPI package dependencies."""
+
+    auth_strategy: ClassVar[AuthStrategy] = None
+    """Authentication/authorization strategy. Declares what credentials are
+    required to use this model provider. Generally should not be `None`."""
+
+    registry: ClassVar[bool] = False
+    """Whether this provider is a registry provider."""
+
+    fields: ClassVar[List[Field]] = []
+    """User inputs expected by this provider when initializing it. Each `Field` `f`
+    should be passed in the constructor as a keyword argument, keyed by `f.key`."""
diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py b/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py
new file mode 100644
index 000000000..f82bd5531
--- /dev/null
+++ b/packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py
@@ -0,0 +1,29 @@
+from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
+from jupyter_ai.models import HumanChatMessage
+
+
+class TestSlashCommand(BaseChatHandler):
+    """
+    A test slash command implementation that developers should build from. The
+    string used to invoke this command is set by the `slash_id` keyword argument
+    in the `routing_type` attribute. The command is mainly implemented in the
+    `process_message()` method. See built-in implementations under
+    `jupyter_ai/handlers` for further reference.
+
+    This slash command is made available to Jupyter AI by the entry point declared
+    in `pyproject.toml`. If this class or parent module is renamed, make sure to
+    update the entry point there as well.
+    """
+
+    id = "test"
+    name = "Test"
+    help = "A test slash command."
+    routing_type = SlashCommandRoutingType(slash_id="test")
+
+    uses_llm = False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    async def process_message(self, message: HumanChatMessage):
+        self.reply("This is the `/test` slash command.")
diff --git a/packages/jupyter-ai-test/jupyter_ai_test/tests/__init__.py b/packages/jupyter-ai-test/jupyter_ai_test/tests/__init__.py
new file mode 100644
index 000000000..bfecc1ade
--- /dev/null
+++ b/packages/jupyter-ai-test/jupyter_ai_test/tests/__init__.py
@@ -0,0 +1 @@
+"""Python unit tests for jupyter_ai_test."""
diff --git a/packages/jupyter-ai-test/package.json b/packages/jupyter-ai-test/package.json
new file mode 100644
index 000000000..92d5d07b7
--- /dev/null
+++ b/packages/jupyter-ai-test/package.json
@@ -0,0 +1,25 @@
+{
+  "name": "@jupyter-ai/test",
+  "version": "2.18.1",
+  "description": "Jupyter AI test package. Not published on NPM or PyPI.",
+  "private": true,
+  "homepage": "https://github.com/jupyterlab/jupyter-ai",
+  "bugs": {
+    "url": "https://github.com/jupyterlab/jupyter-ai/issues",
+    "email": "jupyter@googlegroups.com"
+  },
+  "license": "BSD-3-Clause",
+  "author": {
+    "name": "Project Jupyter",
+    "email": "jupyter@googlegroups.com"
+  },
+  "repository": {
+    "type": "git",
+    "url": "https://github.com/jupyterlab/jupyter-ai.git"
+  },
+  "scripts": {
+    "dev-install": "pip install -e .",
+    "dev-uninstall": "pip uninstall jupyter_ai_test -y",
+    "install-from-src": "pip install ."
+  }
+}
diff --git a/packages/jupyter-ai-test/project.json b/packages/jupyter-ai-test/project.json
new file mode 100644
index 000000000..9af45c206
--- /dev/null
+++ b/packages/jupyter-ai-test/project.json
@@ -0,0 +1,4 @@
+{
+  "name": "@jupyter-ai/test",
+  "implicitDependencies": ["@jupyter-ai/core"]
+}
diff --git a/packages/jupyter-ai-test/pyproject.toml b/packages/jupyter-ai-test/pyproject.toml
new file mode 100644
index 000000000..4401ff736
--- /dev/null
+++ b/packages/jupyter-ai-test/pyproject.toml
@@ -0,0 +1,41 @@
+[build-system]
+requires = ["hatchling>=1.4.0", "jupyterlab~=4.0"]
+build-backend = "hatchling.build"
+
+[project]
+name = "jupyter_ai_test"
+readme = "README.md"
+license = { file = "LICENSE" }
+requires-python = ">=3.8"
+classifiers = [
+    "Framework :: Jupyter",
+    "Framework :: Jupyter :: JupyterLab",
+    "Framework :: Jupyter :: JupyterLab :: 4",
+    "License :: OSI Approved :: BSD License",
+    "Programming Language :: Python",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+]
+version = "0.1.0"
+description = "A Jupyter AI extension."
+authors = [{ name = "Project Jupyter", email = "jupyter@googlegroups.com" }]
+dependencies = ["jupyter_ai"]
+
+[project.optional-dependencies]
+test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"]
+
+[project.entry-points."jupyter_ai.model_providers"]
+test-provider = "jupyter_ai_test.test_providers:TestProvider"
+test-provider-with-streaming = "jupyter_ai_test.test_providers:TestProviderWithStreaming"
+
+[project.entry-points."jupyter_ai.chat_handlers"]
+test-slash-command = "jupyter_ai_test.test_slash_commands:TestSlashCommand"
+
+[tool.hatch.build.hooks.version]
+path = "jupyter_ai_test/_version.py"
+
+[tool.check-wheel-contents]
+ignore = ["W002"]
diff --git a/packages/jupyter-ai-test/setup.py b/packages/jupyter-ai-test/setup.py
new file mode 100644
index 000000000..aefdf20db
--- /dev/null
+++ b/packages/jupyter-ai-test/setup.py
@@ -0,0 +1 @@
+__import__("setuptools").setup()
diff --git a/packages/jupyter-ai/.gitignore b/packages/jupyter-ai/.gitignore
index 48d71c8b7..c919aa0c7 100644
--- a/packages/jupyter-ai/.gitignore
+++ b/packages/jupyter-ai/.gitignore
@@ -7,8 +7,6 @@ node_modules/
 .ipynb_checkpoints
 *.tsbuildinfo
 jupyter_ai/labextension
-# Version file is handled by hatchling
-jupyter_ai/_version.py
 
 # Integration tests
 ui-tests/test-results/
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index 97392168a..fd412eed7 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -183,14 +183,12 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
         Sends an agent message, usually in response to a received
         `HumanChatMessage`.
""" - persona = self.config_manager.persona - agent_msg = AgentChatMessage( id=uuid4().hex, time=time.time(), body=response, reply_to=human_msg.id if human_msg else "", - persona=Persona(name=persona.name, avatar_route=persona.avatar_route), + persona=self.persona, ) for handler in self._root_chat_handlers.values(): @@ -200,7 +198,11 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): handler.broadcast_message(agent_msg) break - def start_pending(self, text: str, ellipsis: bool = True) -> str: + @property + def persona(self): + return self.config_manager.persona + + def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage: """ Sends a pending message to the client. diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 75c5e6023..49b48fe42 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,10 +1,17 @@ -from typing import Dict, List, Type +import time +from typing import Dict, Type +from uuid import uuid4 -from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage +from jupyter_ai.models import ( + AgentStreamChunkMessage, + AgentStreamMessage, + HumanChatMessage, +) from jupyter_ai_magics.providers import BaseProvider -from langchain.chains import ConversationChain, LLMChain -from langchain.memory import ConversationBufferWindowMemory +from langchain_core.messages import AIMessageChunk +from langchain_core.runnables.history import RunnableWithMessageHistory +from ..history import BoundedChatHistory from .base import BaseChatHandler, SlashCommandRoutingType @@ -18,12 +25,12 @@ class DefaultChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.memory = ConversationBufferWindowMemory(return_messages=True, k=2) def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): unified_parameters = { + "verbose": True, **provider_params, **(self.get_model_parameters(provider, provider_params)), } @@ -31,22 +38,87 @@ def create_llm_chain( prompt_template = llm.get_chat_prompt_template() self.llm = llm - self.memory = ConversationBufferWindowMemory( - return_messages=llm.is_chat_provider, k=2 + + runnable = prompt_template | llm + if not llm.manages_history: + history = BoundedChatHistory(k=2) + runnable = RunnableWithMessageHistory( + runnable=runnable, + get_session_history=lambda *args: history, + input_messages_key="input", + history_messages_key="history", + ) + + self.llm_chain = runnable + + def _start_stream(self, human_msg: HumanChatMessage) -> str: + """ + Sends an `agent-stream` message to indicate the start of a response + stream. Returns the ID of the message, denoted as the `stream_id`. 
+ """ + stream_id = uuid4().hex + stream_msg = AgentStreamMessage( + id=stream_id, + time=time.time(), + body="", + reply_to=human_msg.id, + persona=self.persona, + complete=False, ) - if llm.manages_history: - self.llm_chain = LLMChain(llm=llm, prompt=prompt_template, verbose=True) + for handler in self._root_chat_handlers.values(): + if not handler: + continue - else: - self.llm_chain = ConversationChain( - llm=llm, prompt=prompt_template, verbose=True, memory=self.memory - ) + handler.broadcast_message(stream_msg) + break + + return stream_id + + def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False): + """ + Sends an `agent-stream-chunk` message containing content that should be + appended to an existing `agent-stream` message with ID `stream_id`. + """ + stream_chunk_msg = AgentStreamChunkMessage( + id=stream_id, content=content, stream_complete=complete + ) + + for handler in self._root_chat_handlers.values(): + if not handler: + continue + + handler.broadcast_message(stream_chunk_msg) + break async def process_message(self, message: HumanChatMessage): self.get_llm_chain() - with self.pending("Generating response"): - response = await self.llm_chain.apredict( - input=message.body, stop=["\nHuman:"] - ) - self.reply(response, message) + received_first_chunk = False + + # start with a pending message + pending_message = self.start_pending("Generating response") + + # stream response in chunks. this works even if a provider does not + # implement streaming, as `astream()` defaults to yielding `_call()` + # when `_stream()` is not implemented on the LLM class. + async for chunk in self.llm_chain.astream( + {"input": message.body}, + config={"configurable": {"session_id": "static_session"}}, + ): + if not received_first_chunk: + # when receiving the first chunk, close the pending message and + # start the stream. + self.close_pending(pending_message) + stream_id = self._start_stream(human_msg=message) + received_first_chunk = True + + if isinstance(chunk, AIMessageChunk): + self._send_stream_chunk(stream_id, chunk.content) + elif isinstance(chunk, str): + self._send_stream_chunk(stream_id, chunk) + else: + self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}") + break + + # complete stream after all chunks have been streamed + self._send_stream_chunk(stream_id, "", complete=True) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 848a6eed7..f97ff9ee9 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -222,6 +222,9 @@ def initialize_settings(self): # memory object used by the LM chain. self.settings["chat_history"] = [] + # list of pending messages + self.settings["pending_messages"] = [] + # get reference to event loop # `asyncio.get_event_loop()` is deprecated in Python 3.11+, in favor of # the more readable `asyncio.get_event_loop_policy().get_event_loop()`. 
diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py
index a2ef152cb..9f3fcee71 100644
--- a/packages/jupyter-ai/jupyter_ai/handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/handlers.py
@@ -17,11 +17,14 @@
 from .models import (
     AgentChatMessage,
+    AgentStreamChunkMessage,
+    AgentStreamMessage,
     ChatClient,
     ChatHistory,
     ChatMessage,
     ChatRequest,
     ChatUser,
+    ClosePendingMessage,
     ConnectionMessage,
     HumanChatMessage,
     ListProvidersEntry,
@@ -29,6 +32,7 @@
     ListSlashCommandsEntry,
     ListSlashCommandsResponse,
     Message,
+    PendingMessage,
     UpdateConfigRequest,
 )
 
@@ -43,16 +47,18 @@ class ChatHistoryHandler(BaseAPIHandler):
     _messages = []
 
     @property
-    def chat_history(self):
+    def chat_history(self) -> List[ChatMessage]:
         return self.settings["chat_history"]
 
-    @chat_history.setter
-    def _chat_history_setter(self, new_history):
-        self.settings["chat_history"] = new_history
+    @property
+    def pending_messages(self) -> List[PendingMessage]:
+        return self.settings["pending_messages"]
 
     @tornado.web.authenticated
     async def get(self):
-        history = ChatHistory(messages=self.chat_history)
+        history = ChatHistory(
+            messages=self.chat_history, pending_messages=self.pending_messages
+        )
         self.finish(history.json())
 
@@ -88,10 +94,22 @@ def chat_client(self) -> ChatClient:
     def chat_history(self) -> List[ChatMessage]:
         return self.settings["chat_history"]
 
+    @chat_history.setter
+    def chat_history(self, new_history):
+        self.settings["chat_history"] = new_history
+
     @property
     def loop(self) -> AbstractEventLoop:
         return self.settings["jai_event_loop"]
 
+    @property
+    def pending_messages(self) -> List[PendingMessage]:
+        return self.settings["pending_messages"]
+
+    @pending_messages.setter
+    def pending_messages(self, new_pending_messages):
+        self.settings["pending_messages"] = new_pending_messages
+
     def initialize(self):
         self.log.debug("Initializing websocket connection %s", self.request.path)
 
@@ -167,7 +185,14 @@ def open(self):
         self.root_chat_handlers[client_id] = self
         self.chat_clients[client_id] = ChatClient(**current_user, id=client_id)
         self.client_id = client_id
-        self.write_message(ConnectionMessage(client_id=client_id).dict())
+        self.write_message(
+            ConnectionMessage(
+                client_id=client_id,
+                history=ChatHistory(
+                    messages=self.chat_history, pending_messages=self.pending_messages
+                ),
+            ).dict()
+        )
 
         self.log.info(f"Client connected. ID: {client_id}")
         self.log.debug("Clients are : %s", self.root_chat_handlers.keys())
@@ -185,11 +210,32 @@ def broadcast_message(self, message: Message):
             if client:
                 client.write_message(message.dict())
 
-        # Only append ChatMessage instances to history, not control messages
-        if isinstance(message, HumanChatMessage) or isinstance(
-            message, AgentChatMessage
+        # append all messages of type `ChatMessage` directly to the chat history
+        if isinstance(
+            message, (HumanChatMessage, AgentChatMessage, AgentStreamMessage)
         ):
             self.chat_history.append(message)
+        elif isinstance(message, AgentStreamChunkMessage):
+            # for stream chunks, modify the corresponding `AgentStreamMessage`
+            # by appending its content and potentially marking it as complete.
+            chunk: AgentStreamChunkMessage = message
+
+            # iterate backwards from the end of the list
+            for history_message in self.chat_history[::-1]:
+                if (
+                    history_message.type == "agent-stream"
+                    and history_message.id == chunk.id
+                ):
+                    stream_message: AgentStreamMessage = history_message
+                    stream_message.body += chunk.content
+                    stream_message.complete = chunk.stream_complete
+                    break
+        elif isinstance(message, PendingMessage):
+            self.pending_messages.append(message)
+        elif isinstance(message, ClosePendingMessage):
+            self.pending_messages = list(
+                filter(lambda m: m.id != message.id, self.pending_messages)
+            )
 
     async def on_message(self, message):
         self.log.debug("Message received: %s", message)
diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py
new file mode 100644
index 000000000..02b77b911
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/history.py
@@ -0,0 +1,40 @@
+from typing import List, Sequence
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+
+class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
+    """
+    An in-memory implementation of `BaseChatMessageHistory` that stores up to
+    `k` exchanges between a user and an LLM.
+
+    For example, when `k=2`, `BoundedChatHistory` will store up to 2 human
+    messages and 2 AI messages.
+    """
+
+    messages: List[BaseMessage] = Field(default_factory=list)
+    size: int = 0
+    k: int
+
+    async def aget_messages(self) -> List[BaseMessage]:
+        return self.messages
+
+    def add_message(self, message: BaseMessage) -> None:
+        """Add a self-created message to the store"""
+        self.messages.append(message)
+        self.size += 1
+
+        if self.size > self.k * 2:
+            self.messages.pop(0)
+
+    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
+        """Add messages to the store"""
+        self.add_messages(messages)
+
+    def clear(self) -> None:
+        self.messages = []
+
+    async def aclear(self) -> None:
+        self.clear()
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index 9e269a223..ea541568b 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -69,6 +69,20 @@ class AgentChatMessage(BaseModel):
     """
 
 
+class AgentStreamMessage(AgentChatMessage):
+    type: Literal["agent-stream"] = "agent-stream"
+    complete: bool
+    # other attrs inherited from `AgentChatMessage`
+
+
+class AgentStreamChunkMessage(BaseModel):
+    type: Literal["agent-stream-chunk"] = "agent-stream-chunk"
+    id: str
+    content: str
+    stream_complete: bool
+    """Indicates whether this chunk message completes the referenced stream."""
+
+
 class HumanChatMessage(BaseModel):
     type: Literal["human"] = "human"
     id: str
@@ -78,11 +92,6 @@ class HumanChatMessage(BaseModel):
     selection: Optional[Selection]
 
 
-class ConnectionMessage(BaseModel):
-    type: Literal["connection"] = "connection"
-    client_id: str
-
-
 class ClearMessage(BaseModel):
     type: Literal["clear"] = "clear"
 
@@ -105,8 +114,23 @@ class ClosePendingMessage(BaseModel):
 ChatMessage = Union[
     AgentChatMessage,
     HumanChatMessage,
+    AgentStreamMessage,
 ]
 
+
+class ChatHistory(BaseModel):
+    """History of chat messages"""
+
+    messages: List[ChatMessage]
+    pending_messages: List[PendingMessage]
+
+
+class ConnectionMessage(BaseModel):
+    type: Literal["connection"] = "connection"
+    client_id: str
+    history: ChatHistory
+
+
 Message = Union[
     AgentChatMessage,
     HumanChatMessage,
@@ -117,12 +141,6 @@ class ClosePendingMessage(BaseModel):
 ]
 
 
-class ChatHistory(BaseModel):
-    """History of chat messages"""
-
-    messages: List[ChatMessage]
-
-
 class ListProvidersEntry(BaseModel):
     """Model provider with supported models
     and provider's authentication strategy
diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml
index 4328c9aca..15e660cfc 100644
--- a/packages/jupyter-ai/pyproject.toml
+++ b/packages/jupyter-ai/pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
     "importlib_metadata>=5.2.0",
     "jupyter_ai_magics>=2.13.0",
     "dask[distributed]",
-    "faiss-cpu", # Not distributed by official repo
+    "faiss-cpu<=1.8.0", # Not distributed by official repo
     "typing_extensions>=4.5.0",
     "traitlets>=5.0",
     "deepmerge>=1.0",
diff --git a/packages/jupyter-ai/src/chat_handler.ts b/packages/jupyter-ai/src/chat_handler.ts
index fb54df519..f1b131dcf 100644
--- a/packages/jupyter-ai/src/chat_handler.ts
+++ b/packages/jupyter-ai/src/chat_handler.ts
@@ -1,7 +1,9 @@
 import { IDisposable } from '@lumino/disposable';
 import { ServerConnection } from '@jupyterlab/services';
 import { URLExt } from '@jupyterlab/coreutils';
-import { AiService, requestAPI } from './handler';
+import { Signal } from '@lumino/signaling';
+
+import { AiService } from './handler';
 
 const CHAT_SERVICE_URL = 'api/ai/chats';
 
@@ -65,18 +67,6 @@ export class ChatHandler implements IDisposable {
     }
   }
 
-  public async getHistory(): Promise<AiService.ChatHistory> {
-    let data: AiService.ChatHistory = { messages: [] };
-    try {
-      data = await requestAPI<AiService.ChatHistory>('chats/history', {
-        method: 'GET'
-      });
-    } catch (e) {
-      return Promise.reject(e);
-    }
-    return data;
-  }
-
   /**
    * Whether the chat handler is disposed.
   */
@@ -106,23 +96,84 @@ export class ChatHandler implements IDisposable {
     }
   }
 
-  private _onMessage(message: AiService.Message): void {
+  get history(): AiService.ChatHistory {
+    return {
+      messages: this._messages,
+      pending_messages: this._pendingMessages
+    };
+  }
+
+  get historyChanged(): Signal<this, AiService.ChatHistory> {
+    return this._historyChanged;
+  }
+
+  private _onMessage(newMessage: AiService.Message): void {
     // resolve promise from `sendMessage()`
-    if (message.type === 'human' && message.client.id === this.id) {
-      this._sendResolverQueue.shift()?.(message.id);
+    if (newMessage.type === 'human' && newMessage.client.id === this.id) {
+      this._sendResolverQueue.shift()?.(newMessage.id);
     }
 
     // resolve promise from `replyFor()` if it exists
     if (
-      message.type === 'agent' &&
-      message.reply_to in this._replyForResolverDict
+      newMessage.type === 'agent' &&
+      newMessage.reply_to in this._replyForResolverDict
     ) {
-      this._replyForResolverDict[message.reply_to](message);
-      delete this._replyForResolverDict[message.reply_to];
+      this._replyForResolverDict[newMessage.reply_to](newMessage);
+      delete this._replyForResolverDict[newMessage.reply_to];
     }
 
     // call listeners in serial
-    this._listeners.forEach(listener => listener(message));
+    this._listeners.forEach(listener => listener(newMessage));
+
+    // append message to chat history. this block should always set `_messages`
+    // or `_pendingMessages` to a new array instance rather than modifying
+    // in-place so consumer React components re-render.
+    switch (newMessage.type) {
+      case 'connection':
+        break;
+      case 'clear':
+        this._messages = [];
+        break;
+      case 'pending':
+        this._pendingMessages = [...this._pendingMessages, newMessage];
+        break;
+      case 'close-pending':
+        this._pendingMessages = this._pendingMessages.filter(
+          p => p.id !== newMessage.id
+        );
+        break;
+      case 'agent-stream-chunk': {
+        const target = newMessage.id;
+        const streamMessage = this._messages.find(
+          (m): m is AiService.AgentStreamMessage =>
+            m.type === 'agent-stream' && m.id === target
+        );
+        if (!streamMessage) {
+          console.error(
+            `Received stream chunk with ID ${target}, but no agent-stream message with that ID exists. ` +
+              'Ignoring this stream chunk.'
+          );
+          break;
+        }
+
+        streamMessage.body += newMessage.content;
+        if (newMessage.stream_complete) {
+          streamMessage.complete = true;
+        }
+        this._messages = [...this._messages];
+        break;
+      }
+      default:
+        // human or agent chat message
+        this._messages = [...this._messages, newMessage];
+        break;
+    }
+
+    // finally, trigger `historyChanged` signal
+    this._historyChanged.emit({
+      messages: this._messages,
+      pending_messages: this._pendingMessages
+    });
   }
 
   /**
@@ -173,6 +224,11 @@ export class ChatHandler implements IDisposable {
         return;
       }
       this.id = message.client_id;
+
+      // initialize chat history from `ConnectionMessage`
+      this._messages = message.history.messages;
+      this._pendingMessages = message.history.pending_messages;
+
       resolve();
       this.removeListener(listenForConnection);
     };
@@ -184,4 +240,17 @@ export class ChatHandler implements IDisposable {
   private _isDisposed = false;
   private _socket: WebSocket | null = null;
   private _listeners: ((msg: any) => void)[] = [];
+
+  /**
+   * The list of chat messages
+   */
+  private _messages: AiService.ChatMessage[] = [];
+  private _pendingMessages: AiService.PendingMessage[] = [];
+
+  /**
+   * Signal for when the chat history is changed. Components rendering the chat
+   * history should subscribe to this signal and update their state when this
+   * signal is triggered.
+   */
+  private _historyChanged = new Signal<this, AiService.ChatHistory>(this);
 }
diff --git a/packages/jupyter-ai/src/components/chat-messages.tsx b/packages/jupyter-ai/src/components/chat-messages.tsx
index 77fbbbfd1..e6cef48ec 100644
--- a/packages/jupyter-ai/src/components/chat-messages.tsx
+++ b/packages/jupyter-ai/src/components/chat-messages.tsx
@@ -203,6 +203,9 @@ export function ChatMessages(props: ChatMessagesProps): JSX.Element {
           );
diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx
index 87009ecf5..c23027ca0 100644
--- a/packages/jupyter-ai/src/components/chat-settings.tsx
+++ b/packages/jupyter-ai/src/components/chat-settings.tsx
@@ -354,6 +354,7 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
           )}
           {lmGlobalId && (
@@ -448,6 +449,7 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
           )}
           {clmGlobalId && (
diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx
index 4c232ed5b..c84ae022b 100644
--- a/packages/jupyter-ai/src/components/chat.tsx
+++ b/packages/jupyter-ai/src/components/chat.tsx
@@ -38,10 +38,12 @@ function ChatBody({
   setChatView: chatViewHandler,
   rmRegistry: renderMimeRegistry
 }: ChatBodyProps): JSX.Element {
-  const [messages, setMessages] = useState<AiService.ChatMessage[]>([]);
+  const [messages, setMessages] = useState<AiService.ChatMessage[]>([
+    ...chatHandler.history.messages
+  ]);
   const [pendingMessages, setPendingMessages] = useState<
     AiService.PendingMessage[]
-  >([]);
+  >([...chatHandler.history.pending_messages]);
   const [showWelcomeMessage, setShowWelcomeMessage] = useState(false);
   const [includeSelection, setIncludeSelection] = useState(true);
   const [replaceSelection, setReplaceSelection] = useState(false);
@@ -50,17 +52,13 @@ function ChatBody({
   const [sendWithShiftEnter, setSendWithShiftEnter] = useState(true);
 
   /**
-   * Effect: fetch history and config on initial render
+   * Effect: fetch config on initial render
    */
   useEffect(() => {
-    async function fetchHistory() {
+    async function fetchConfig() {
       try {
-        const [history, config] = await Promise.all([
-          chatHandler.getHistory(),
-          AiService.getConfig()
-        ]);
+        const config = await AiService.getConfig();
         setSendWithShiftEnter(config.send_with_shift_enter ?? false);
-        setMessages(history.messages);
         if (!config.model_provider_id) {
           setShowWelcomeMessage(true);
         }
@@ -69,37 +67,22 @@ function ChatBody({
     }
 
-    fetchHistory();
+    fetchConfig();
   }, [chatHandler]);
 
   /**
    * Effect: listen to chat messages
    */
   useEffect(() => {
-    function handleChatEvents(message: AiService.Message) {
-      switch (message.type) {
-        case 'connection':
-          return;
-        case 'clear':
-          setMessages([]);
-          return;
-        case 'pending':
-          setPendingMessages(pendingMessages => [...pendingMessages, message]);
-          return;
-        case 'close-pending':
-          setPendingMessages(pendingMessages =>
-            pendingMessages.filter(p => p.id !== message.id)
-          );
-          return;
-        default:
-          setMessages(messageGroups => [...messageGroups, message]);
-          return;
-      }
+    function onHistoryChange(_: unknown, history: AiService.ChatHistory) {
+      setMessages([...history.messages]);
+      setPendingMessages([...history.pending_messages]);
     }
 
-    chatHandler.addListener(handleChatEvents);
+    chatHandler.historyChanged.connect(onHistoryChange);
+
     return function cleanup() {
-      chatHandler.removeListener(handleChatEvents);
+      chatHandler.historyChanged.disconnect(onHistoryChange);
     };
   }, [chatHandler]);
diff --git a/packages/jupyter-ai/src/components/rendermime-markdown.tsx b/packages/jupyter-ai/src/components/rendermime-markdown.tsx
index 570a43e42..0e1ca8e9a 100644
--- a/packages/jupyter-ai/src/components/rendermime-markdown.tsx
+++ b/packages/jupyter-ai/src/components/rendermime-markdown.tsx
@@ -1,4 +1,4 @@
-import React, { useState, useEffect } from 'react';
+import React, { useState, useEffect, useRef } from 'react';
 import { createPortal } from 'react-dom';
 
 import { CodeToolbar, CodeToolbarProps } from './code-blocks/code-toolbar';
@@ -10,6 +10,12 @@ const RENDERMIME_MD_CLASS = 'jp-ai-rendermime-markdown';
 type RendermimeMarkdownProps = {
   markdownStr: string;
   rmRegistry: IRenderMimeRegistry;
+  /**
+   * Whether the message is complete. This is generally `true` except in the
+   * case where `markdownStr` contains the incomplete contents of a
+   * `AgentStreamMessage`, in which case this should be set to `false`.
+   */
+  complete: boolean;
 };
 
 /**
@@ -24,15 +30,26 @@ function escapeLatexDelimiters(text: string) {
 }
 
 function RendermimeMarkdownBase(props: RendermimeMarkdownProps): JSX.Element {
-  const [renderedContent, setRenderedContent] = useState<HTMLElement | null>(
-    null
-  );
+  // create a single renderer object at component mount
+  const [renderer] = useState(() => {
+    return props.rmRegistry.createRenderer(MD_MIME_TYPE);
+  });
+
+  // ref that tracks the content container to store the rendermime node in
+  const renderingContainer = useRef<HTMLDivElement | null>(null);
+  // ref that tracks whether the rendermime node has already been inserted
+  const renderingInserted = useRef(false);
 
   // each element is a two-tuple with the structure [codeToolbarRoot, codeToolbarProps].
   const [codeToolbarDefns, setCodeToolbarDefns] = useState<
     Array<[HTMLDivElement, CodeToolbarProps]>
   >([]);
 
+  /**
+   * Effect: use Rendermime to render `props.markdownStr` into an HTML element,
+   * and insert it into `renderingContainer` if not yet inserted. When the
+   * message is completed, add code toolbars.
+   */
   useEffect(() => {
     const renderContent = async () => {
       const mdStr = escapeLatexDelimiters(props.markdownStr);
@@ -40,7 +57,6 @@ function RendermimeMarkdownBase(props: RendermimeMarkdownProps): JSX.Element {
         data: { [MD_MIME_TYPE]: mdStr }
       });
 
-      const renderer = props.rmRegistry.createRenderer(MD_MIME_TYPE);
       await renderer.renderModel(model);
       props.rmRegistry.latexTypesetter?.typeset(renderer.node);
       if (!renderer.node) {
@@ -49,6 +65,16 @@ function RendermimeMarkdownBase(props: RendermimeMarkdownProps): JSX.Element {
         );
       }
 
+      // insert the rendering into renderingContainer if not yet inserted
+      if (renderingContainer.current !== null && !renderingInserted.current) {
+        renderingContainer.current.appendChild(renderer.node);
+        renderingInserted.current = true;
+      }
+
+      // if complete, render code toolbars
+      if (!props.complete) {
+        return;
+      }
       const newCodeToolbarDefns: [HTMLDivElement, CodeToolbarProps][] = [];
 
       // Attach CodeToolbar root element to each <pre> block
@@ -66,17 +92,14 @@ function RendermimeMarkdownBase(props: RendermimeMarkdownProps): JSX.Element {
       });
 
       setCodeToolbarDefns(newCodeToolbarDefns);
-      setRenderedContent(renderer.node);
     };
 
     renderContent();
-  }, [props.markdownStr, props.rmRegistry]);
+  }, [props.markdownStr, props.complete, props.rmRegistry]);
 
   return (
     <div className={RENDERMIME_MD_CLASS}>
-      {renderedContent && (
-        <div ref={node => node && node.appendChild(renderedContent)} />
-      )}
+      <div ref={renderingContainer} />
       {
         // Render a `CodeToolbar` element underneath each code block.
         // We use ReactDOM.createPortal() so each `CodeToolbar` element is able
diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts
index c8f457fe2..b93a4571a 100644
--- a/packages/jupyter-ai/src/handler.ts
+++ b/packages/jupyter-ai/src/handler.ts
@@ -109,6 +109,7 @@ export namespace AiService {
   export type ConnectionMessage = {
     type: 'connection';
     client_id: string;
+    history: ChatHistory;
   };
 
   export type ClearMessage = {
@@ -129,17 +130,36 @@ export namespace AiService {
     id: string;
   };
 
-  export type ChatMessage = AgentChatMessage | HumanChatMessage;
+  export type AgentStreamMessage = Omit<AgentChatMessage, 'type'> & {
+    type: 'agent-stream';
+    complete: boolean;
+  };
+
+  export type AgentStreamChunkMessage = {
+    type: 'agent-stream-chunk';
+    id: string;
+    content: string;
+    stream_complete: boolean;
+  };
+
+  export type ChatMessage =
+    | AgentChatMessage
+    | HumanChatMessage
+    | AgentStreamMessage;
+
   export type Message =
     | AgentChatMessage
     | HumanChatMessage
     | ConnectionMessage
     | ClearMessage
    | PendingMessage
-    | ClosePendingMessage;
+    | ClosePendingMessage
+    | AgentStreamMessage
+    | AgentStreamChunkMessage;
 
   export type ChatHistory = {
     messages: ChatMessage[];
+    pending_messages: PendingMessage[];
   };
 
   export type DescribeConfigResponse = {
diff --git a/yarn.lock b/yarn.lock
index 710e780ca..76ddf97ca 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -2279,6 +2279,12 @@ __metadata:
   languageName: unknown
   linkType: soft
 
+"@jupyter-ai/test@workspace:packages/jupyter-ai-test":
+  version: 0.0.0-use.local
+  resolution: "@jupyter-ai/test@workspace:packages/jupyter-ai-test"
+  languageName: unknown
+  linkType: soft
+
 "@jupyter/collaboration@npm:^1":
   version: 1.2.1
   resolution: "@jupyter/collaboration@npm:1.2.1"