Implement LLM response streaming (jupyterlab#859)
* minimal implementation of chat streaming

* improve chat history handling

- ensures users never miss streamed chunks when joining
- also removes temporary print/log statements introduced in prev commit

* add jupyter_ai_test package for developer testing

* pre-commit

* improve readability of for loop finding stream msg

Co-authored-by: Piyush Jain <[email protected]>

* remove _version.py

* remove unused ConversationBufferWindowMemory

* update jupyter_ai_test README

* add _version.py files to top-level .gitignore

* pre-commit

---------

Co-authored-by: Piyush Jain <[email protected]>
2 people authored and Marchlak committed Oct 28, 2024
1 parent 18b7b6d commit f5382ae
Showing 29 changed files with 740 additions and 114 deletions.
5 changes: 3 additions & 2 deletions .gitignore
@@ -135,5 +135,6 @@ dev.sh

.conda/

# reserved for testing cookiecutter
packages/jupyter-ai-test
# Version files are auto-generated by Hatchling and should not be committed to
# the source repo.
packages/**/_version.py
2 changes: 0 additions & 2 deletions packages/jupyter-ai-magics/.gitignore
@@ -8,8 +8,6 @@ node_modules/
.ipynb_checkpoints
*.tsbuildinfo
jupyter_ai_magics/labextension
# Version file is handled by hatchling
jupyter_ai_magics/_version.py

# Integration tests
ui-tests/test-results/
21 changes: 20 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -19,7 +19,6 @@
)

from jsonpath_ng import parse
from langchain.chat_models.base import BaseChatModel
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
@@ -42,6 +41,8 @@
    Together,
)
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM

# this is necessary because `langchain.pydantic_v1.main` does not include
# `ModelMetaclass`, as it is not listed in `__all__` by the `pydantic.main`
@@ -448,6 +449,24 @@ def is_chat_provider(self):
    def allows_concurrency(self):
        return True

    @property
    def _supports_sync_streaming(self):
        if self.is_chat_provider:
            return not (self.__class__._stream is BaseChatModel._stream)
        else:
            return not (self.__class__._stream is BaseLLM._stream)

    @property
    def _supports_async_streaming(self):
        if self.is_chat_provider:
            return not (self.__class__._astream is BaseChatModel._astream)
        else:
            return not (self.__class__._astream is BaseLLM._astream)

    @property
    def supports_streaming(self):
        return self._supports_sync_streaming or self._supports_async_streaming

    async def generate_inline_completions(
        self, request: InlineCompletionRequest
    ) -> InlineCompletionReply:
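The three properties added above detect streaming support by checking whether a provider class overrides `_stream`/`_astream` from the LangChain base classes. A minimal sketch of what that looks like from the caller's side, using the test providers added later in this commit (the constructor kwargs here are an assumption about how providers are instantiated):

```python
from jupyter_ai_test.test_providers import TestProvider, TestProviderWithStreaming

# TestLLM inherits `_stream` unchanged from BaseLLM, so the identity check
# in `_supports_sync_streaming` reports no streaming support.
print(TestProvider(model_id="test").supports_streaming)               # False

# TestLLMWithStreaming overrides `_stream`, so the same check returns True.
print(TestProviderWithStreaming(model_id="test").supports_streaming)  # True
```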
29 changes: 29 additions & 0 deletions packages/jupyter-ai-test/LICENSE
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2024, Project Jupyter
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
58 changes: 58 additions & 0 deletions packages/jupyter-ai-test/README.md
@@ -0,0 +1,58 @@
# jupyter_ai_test

`jupyter_ai_test` is a Jupyter AI module that registers additional model
providers and slash commands for testing Jupyter AI in a local development
environment. This package should never be published on NPM or PyPI.

## Requirements

- Python 3.8 - 3.11
- JupyterLab 4

## Install

To install the extension, execute:

```bash
pip install jupyter_ai_test
```

## Uninstall

To remove the extension, execute:

```bash
pip uninstall jupyter_ai_test
```

## Contributing

### Development install

```bash
cd jupyter-ai-test
pip install -e "."
```

### Development uninstall

```bash
pip uninstall jupyter_ai_test
```

#### Backend tests

This package uses [Pytest](https://docs.pytest.org/) for Python testing.

Install test dependencies (needed only once):

```sh
cd jupyter-ai-test
pip install -e ".[test]"
```

To execute them, run:

```sh
pytest -vv -r ap --cov jupyter_ai_test
```
1 change: 1 addition & 0 deletions packages/jupyter-ai-test/jupyter_ai_test/__init__.py
@@ -0,0 +1 @@
from ._version import __version__
57 changes: 57 additions & 0 deletions packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
@@ -0,0 +1,57 @@
import time
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs.generation import GenerationChunk


class TestLLM(LLM):
    model_id: str = "test"

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        time.sleep(3)
        return "Hello! This is a dummy response from a test LLM."


class TestLLMWithStreaming(LLM):
    model_id: str = "test"

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        time.sleep(3)
        return "Hello! This is a dummy response from a test LLM."

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        time.sleep(5)
        yield GenerationChunk(
            text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n"
        )
        for i in range(1, 101):
            time.sleep(0.5)
            yield GenerationChunk(text=f"{i}, ")
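As a quick check that the `_stream()` override behaves as intended, the streaming test model can be driven through LangChain's standard `stream()` interface, which yields the generated chunks one at a time (a usage sketch, not part of this diff):

```python
from jupyter_ai_test.test_llms import TestLLMWithStreaming

llm = TestLLMWithStreaming()

# `stream()` invokes the `_stream()` override above; each GenerationChunk's
# text is yielded as it is produced, so the counting appears incrementally.
for chunk in llm.stream("count to 100"):
    print(chunk, end="", flush=True)
```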
77 changes: 77 additions & 0 deletions packages/jupyter-ai-test/jupyter_ai_test/test_providers.py
@@ -0,0 +1,77 @@
from typing import ClassVar, List

from jupyter_ai import AuthStrategy, BaseProvider, Field

from .test_llms import TestLLM, TestLLMWithStreaming


class TestProvider(BaseProvider, TestLLM):
    id: ClassVar[str] = "test-provider"
    """ID for this provider class."""

    name: ClassVar[str] = "Test Provider"
    """User-facing name of this provider."""

    models: ClassVar[List[str]] = ["test"]
    """List of supported models by their IDs. For registry providers, this will
    be just ["*"]."""

    help: ClassVar[str] = None
    """Text to display in lieu of a model list for a registry provider that does
    not provide a list of models."""

    model_id_key: ClassVar[str] = "model_id"
    """Kwarg expected by the upstream LangChain provider."""

    model_id_label: ClassVar[str] = "Model ID"
    """Human-readable label of the model ID."""

    pypi_package_deps: ClassVar[List[str]] = []
    """List of PyPI package dependencies."""

    auth_strategy: ClassVar[AuthStrategy] = None
    """Authentication/authorization strategy. Declares what credentials are
    required to use this model provider. Generally should not be `None`."""

    registry: ClassVar[bool] = False
    """Whether this provider is a registry provider."""

    fields: ClassVar[List[Field]] = []
    """User inputs expected by this provider when initializing it. Each `Field` `f`
    should be passed in the constructor as a keyword argument, keyed by `f.key`."""


class TestProviderWithStreaming(BaseProvider, TestLLMWithStreaming):
    id: ClassVar[str] = "test-provider-with-streaming"
    """ID for this provider class."""

    name: ClassVar[str] = "Test Provider (streaming)"
    """User-facing name of this provider."""

    models: ClassVar[List[str]] = ["test"]
    """List of supported models by their IDs. For registry providers, this will
    be just ["*"]."""

    help: ClassVar[str] = None
    """Text to display in lieu of a model list for a registry provider that does
    not provide a list of models."""

    model_id_key: ClassVar[str] = "model_id"
    """Kwarg expected by the upstream LangChain provider."""

    model_id_label: ClassVar[str] = "Model ID"
    """Human-readable label of the model ID."""

    pypi_package_deps: ClassVar[List[str]] = []
    """List of PyPI package dependencies."""

    auth_strategy: ClassVar[AuthStrategy] = None
    """Authentication/authorization strategy. Declares what credentials are
    required to use this model provider. Generally should not be `None`."""

    registry: ClassVar[bool] = False
    """Whether this provider is a registry provider."""

    fields: ClassVar[List[Field]] = []
    """User inputs expected by this provider when initializing it. Each `Field` `f`
    should be passed in the constructor as a keyword argument, keyed by `f.key`."""
29 changes: 29 additions & 0 deletions packages/jupyter-ai-test/jupyter_ai_test/test_slash_commands.py
@@ -0,0 +1,29 @@
from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
from jupyter_ai.models import HumanChatMessage


class TestSlashCommand(BaseChatHandler):
    """
    A test slash command implementation that developers should build from. The
    string used to invoke this command is set by the `slash_id` keyword argument
    in the `routing_type` attribute. The command is mainly implemented in the
    `process_message()` method. See built-in implementations under
    `jupyter_ai/handlers` for further reference.

    This command is made available to Jupyter AI by the entry point declared in
    `pyproject.toml`. If this class or its parent module is renamed, make sure to
    update the entry point there as well.
    """

    id = "test"
    name = "Test"
    help = "A test slash command."
    routing_type = SlashCommandRoutingType(slash_id="test")

    uses_llm = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def process_message(self, message: HumanChatMessage):
        self.reply("This is the `/test` slash command.")
1 change: 1 addition & 0 deletions packages/jupyter-ai-test/jupyter_ai_test/tests/__init__.py
@@ -0,0 +1 @@
"""Python unit tests for jupyter_ai_test."""
25 changes: 25 additions & 0 deletions packages/jupyter-ai-test/package.json
@@ -0,0 +1,25 @@
{
  "name": "@jupyter-ai/test",
  "version": "2.18.1",
  "description": "Jupyter AI test package. Not published on NPM or PyPI.",
  "private": true,
  "homepage": "https://github.com/jupyterlab/jupyter-ai",
  "bugs": {
    "url": "https://github.com/jupyterlab/jupyter-ai/issues",
    "email": "[email protected]"
  },
  "license": "BSD-3-Clause",
  "author": {
    "name": "Project Jupyter",
    "email": "[email protected]"
  },
  "repository": {
    "type": "git",
    "url": "https://github.com/jupyterlab/jupyter-ai.git"
  },
  "scripts": {
    "dev-install": "pip install -e .",
    "dev-uninstall": "pip uninstall jupyter_ai_test -y",
    "install-from-src": "pip install ."
  }
}
4 changes: 4 additions & 0 deletions packages/jupyter-ai-test/project.json
@@ -0,0 +1,4 @@
{
"name": "@jupyter-ai/test",
"implicitDependencies": ["@jupyter-ai/core"]
}
41 changes: 41 additions & 0 deletions packages/jupyter-ai-test/pyproject.toml
@@ -0,0 +1,41 @@
[build-system]
requires = ["hatchling>=1.4.0", "jupyterlab~=4.0"]
build-backend = "hatchling.build"

[project]
name = "jupyter_ai_test"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.8"
classifiers = [
"Framework :: Jupyter",
"Framework :: Jupyter :: JupyterLab",
"Framework :: Jupyter :: JupyterLab :: 4",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]
version = "0.1.0"
description = "A Jupyter AI extension."
authors = [{ name = "Project Jupyter", email = "[email protected]" }]
dependencies = ["jupyter_ai"]

[project.optional-dependencies]
test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"]

[project.entry-points."jupyter_ai.model_providers"]
test-provider = "jupyter_ai_test.test_providers:TestProvider"
test-provider-with-streaming = "jupyter_ai_test.test_providers:TestProviderWithStreaming"

[project.entry-points."jupyter_ai.chat_handlers"]
test-slash-command = "jupyter_ai_test.test_slash_commands:TestSlashCommand"

[tool.hatch.build.hooks.version]
path = "jupyter_ai_test/_version.py"

[tool.check-wheel-contents]
ignore = ["W002"]
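For context, the two entry-point groups above are how Jupyter AI discovers third-party providers and slash commands at runtime. A rough sketch of that discovery mechanism using the standard library (illustrative only; Jupyter AI's own loading code may differ in detail, and the `group=` keyword requires Python 3.10+):

```python
from importlib.metadata import entry_points

# Iterate over every installed package that registered a model provider entry
# point, including the two test providers declared in this pyproject.toml.
for ep in entry_points(group="jupyter_ai.model_providers"):
    provider_class = ep.load()  # imports e.g. jupyter_ai_test.test_providers:TestProvider
    print(f"{ep.name} -> {provider_class.__name__}")
```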
1 change: 1 addition & 0 deletions packages/jupyter-ai-test/setup.py
@@ -0,0 +1 @@
__import__("setuptools").setup()