Skip to content

Commit

Permalink
Add back history and reset subcommand in magics (jupyterlab#997)
Browse files Browse the repository at this point in the history
* Add back history and reset subcommand in magics

This time it's not specific to OpenAI.

* Document `%ai reset`

* Add the `max_history` magic command option

* Respect `max_history` when calling the model

All history is still kept, so that configuring a higher `max_history`
mid-session causes older exchanges to be included again.

* Use `<role>:` prefix instead of pseudo-XML for non-chat models

* Document the `max_history` setting for the magic commands

* fix: make `max_history` a configuration option

* fix: `max_history=0` to generate with no context

* Add unit test for `max_history` magic option

* Add unit test for `%ai reset` magic command

* Test that transcript is updated with new prompt

* Review corrections

- `context` instead of `self.transcript` as if clause condition
- less confusing model context input variable
  • Loading branch information
akaihola authored Oct 7, 2024
1 parent 1ec0936 commit b15ebb5
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 4 deletions.
28 changes: 28 additions & 0 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,34 @@ A function that computes the lowest common multiples of two integers, and
a function that runs 5 test cases of the lowest common multiple function
```

### Configuring the amount of history to include in the context

By default, two previous Human/AI message exchanges are included in the context of the new prompt.
You can change this using the IPython `%config` magic, for example:

```python
%config AiMagics.max_history = 4
```

Note that old messages are still kept locally in memory,
so they will be included in the context of the next prompt after raising the `max_history` value.

You can configure the value for all notebooks
by specifying the `c.AiMagics.max_history` traitlet in `ipython_config.py`, for example:

```python
c.AiMagics.max_history = 4
```

### Clearing the chat history

You can run the `%ai reset` line magic command to clear the chat history. After you do this,
previous magic commands you've run will no longer be added as context in requests.

```
%ai reset
```

### Interpolating in prompts

Using curly brace syntax, you can include variables and other Python expressions in your
Expand Down
44 changes: 41 additions & 3 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from jupyter_ai_magics.utils import decompose_model_id, get_lm_providers
from langchain.chains import LLMChain
from langchain.schema import HumanMessage
from langchain_core.messages import AIMessage

from ._version import __version__
from .parsers import (
Expand All @@ -24,6 +25,7 @@
HelpArgs,
ListArgs,
RegisterArgs,
ResetArgs,
UpdateArgs,
VersionArgs,
cell_magic_parser,
Expand Down Expand Up @@ -144,9 +146,18 @@ class AiMagics(Magics):
config=True,
)

max_history = traitlets.Int(
default_value=2,
allow_none=False,
help="""Maximum number of exchanges (user/assistant) to include in the history
when invoking a chat model, defaults to 2.
""",
config=True,
)

def __init__(self, shell):
super().__init__(shell)
self.transcript_openai = []
self.transcript = []

# suppress warning when using old Anthropic provider
warnings.filterwarnings(
Expand Down Expand Up @@ -437,6 +448,12 @@ def handle_error(self, args: ErrorArgs):

return self.run_ai_cell(cell_args, prompt)

def _append_exchange(self, prompt: str, output: str) -> None:
    """Append a user/assistant exchange to the conversation transcript.

    The transcript is provider-agnostic (not OpenAI-specific) and is
    consulted by later prompts, which include up to ``max_history``
    exchanges from it as context.

    :param prompt: the user's prompt text (after interpolation).
    :param output: the model's response text.
    """
    self.transcript.append(HumanMessage(prompt))
    self.transcript.append(AIMessage(output))

def _decompose_model_id(self, model_id: str):
"""Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
# custom_model_registry maps keys to either a model name (a string) or an LLMChain.
Expand Down Expand Up @@ -500,6 +517,9 @@ def handle_list(self, args: ListArgs):
def handle_version(self, args: VersionArgs):
    """Handle `%ai version`: return the package version string."""
    return __version__

def handle_reset(self, args: ResetArgs) -> None:
    """Handle `%ai reset`: discard the conversation transcript.

    Rebinds ``self.transcript`` to a fresh empty list so that no prior
    exchanges are included as context in subsequent prompts.
    """
    self.transcript = []

def run_ai_cell(self, args: CellArgs, prompt: str):
provider_id, local_model_id = self._decompose_model_id(args.model_id)

Expand Down Expand Up @@ -577,13 +597,29 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
ip = self.shell
prompt = prompt.format_map(FormatDict(ip.user_ns))

context = self.transcript[-2 * self.max_history :] if self.max_history else []
if provider.is_chat_provider:
result = provider.generate([[HumanMessage(content=prompt)]])
result = provider.generate([[*context, HumanMessage(content=prompt)]])
else:
# generate output from model via provider
result = provider.generate([prompt])
if context:
inputs = [
(
f"AI: {message.content}"
if message.type == "ai"
else f"{message.type.title()}: {message.content}"
)
for message in context + [HumanMessage(content=prompt)]
]
else:
inputs = [prompt]
result = provider.generate(inputs)

output = result.generations[0][0].text

# append exchange to transcript
self._append_exchange(prompt, output)

md = {"jupyter_ai": {"provider_id": provider_id, "model_id": local_model_id}}

return self.display_output(output, args.format, md)
Expand Down Expand Up @@ -628,6 +664,8 @@ def ai(self, line, cell=None):
return self.handle_update(args)
if args.type == "version":
return self.handle_version(args)
if args.type == "reset":
return self.handle_reset(args)
except ValueError as e:
print(e, file=sys.stderr)
return
Expand Down
13 changes: 13 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class UpdateArgs(BaseModel):
target: str


class ResetArgs(BaseModel):
    """Parsed arguments for the `%ai reset` line magic (takes no options)."""

    type: Literal["reset"] = "reset"


class LineMagicGroup(click.Group):
"""Helper class to print the help string for cell magics as well when
`%ai --help` is called."""
Expand Down Expand Up @@ -277,3 +281,12 @@ def register_subparser(**kwargs):
def register_subparser(**kwargs):
"""Update an alias called NAME to refer to the model or chain named TARGET."""
return UpdateArgs(**kwargs)


@line_magic_parser.command(
    name="reset",
    short_help="Clear the conversation transcript.",
)
# `reset` accepts no options; **kwargs is kept only to match the shared
# signature used by the other subcommand registrations in this module.
def register_subparser(**kwargs):
    """Clear the conversation transcript."""
    return ResetArgs()
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from unittest.mock import patch
import os
from unittest.mock import Mock, patch

import pytest
from IPython import InteractiveShell
from IPython.core.display import Markdown
from jupyter_ai_magics.magics import AiMagics
from langchain_core.messages import AIMessage, HumanMessage
from pytest import fixture
from traitlets.config.loader import Config

Expand Down Expand Up @@ -48,3 +52,71 @@ def test_default_model_error_line(ip):
assert mock_run.called
cell_args = mock_run.call_args.args[0]
assert cell_args.model_id == "my-favourite-llm"


# The prompt the magic is expected to send for the cell body used in the
# tests below ("Write code for me please" plus the markdown format suffix).
PROMPT = HumanMessage(
    content=("Write code for me please\n\nProduce output in markdown format only.")
)
# The mocked model response used across the tests.
RESPONSE = AIMessage(content="Leet code")
# Short transcript entries for exercising `max_history` slicing.
AI1 = AIMessage("ai1")
H1 = HumanMessage("h1")
AI2 = AIMessage("ai2")
H2 = HumanMessage("h2")
AI3 = AIMessage("ai3")


@pytest.mark.parametrize(
    ["transcript", "max_history", "expected_context"],
    [
        # Each case: (pre-existing transcript, max_history setting,
        # messages expected to be passed to the model). `max_history`
        # counts exchanges, i.e. up to 2*max_history transcript messages
        # are included; 0 means no context at all.
        ([], 3, [PROMPT]),
        ([AI1], 0, [PROMPT]),
        ([AI1], 1, [AI1, PROMPT]),
        ([H1, AI1], 0, [PROMPT]),
        ([H1, AI1], 1, [H1, AI1, PROMPT]),
        ([AI1, H1, AI2], 0, [PROMPT]),
        ([AI1, H1, AI2], 1, [H1, AI2, PROMPT]),
        ([AI1, H1, AI2], 2, [AI1, H1, AI2, PROMPT]),
        ([H1, AI1, H2, AI2], 0, [PROMPT]),
        ([H1, AI1, H2, AI2], 1, [H2, AI2, PROMPT]),
        ([H1, AI1, H2, AI2], 2, [H1, AI1, H2, AI2, PROMPT]),
        ([AI1, H1, AI2, H2, AI3], 0, [PROMPT]),
        ([AI1, H1, AI2, H2, AI3], 1, [H2, AI3, PROMPT]),
        ([AI1, H1, AI2, H2, AI3], 2, [H1, AI2, H2, AI3, PROMPT]),
        ([AI1, H1, AI2, H2, AI3], 3, [AI1, H1, AI2, H2, AI3, PROMPT]),
    ],
)
def test_max_history(ip, transcript, max_history, expected_context):
    """The `%%ai` cell magic sends at most `max_history` prior exchanges."""
    ip.extension_manager.load_extension("jupyter_ai_magics")
    ai_magics = ip.magics_manager.registry["AiMagics"]
    # Seed the transcript; copy so the parametrized constant isn't mutated
    # when the magic appends the new exchange.
    ai_magics.transcript = transcript.copy()
    ai_magics.max_history = max_history
    provider = ai_magics._get_provider("openrouter")
    # Mock the provider's generate() and supply a dummy API key so the
    # magic can run without network access or real credentials.
    with patch.object(provider, "generate") as generate, patch.dict(
        os.environ, OPENROUTER_API_KEY="123"
    ):
        generate.return_value.generations = [[Mock(text="Leet code")]]
        result = ip.run_cell_magic(
            "ai",
            "openrouter:anthropic/claude-3.5-sonnet",
            cell="Write code for me please",
        )
    # The model must receive exactly the trimmed context plus the new prompt.
    provider.generate.assert_called_once_with([expected_context])
    assert isinstance(result, Markdown)
    assert result.data == "Leet code"
    assert result.filename is None
    assert result.metadata == {
        "jupyter_ai": {
            "model_id": "anthropic/claude-3.5-sonnet",
            "provider_id": "openrouter",
        }
    }
    assert result.url is None
    # The full exchange is recorded regardless of the max_history setting.
    assert ai_magics.transcript == [*transcript, PROMPT, RESPONSE]


def test_reset(ip):
    """`%ai reset` empties the transcript and produces no output.

    Fix: the original assigned ``result`` but never checked it; since
    ``handle_reset`` returns nothing, the line magic should yield ``None``.
    """
    ip.extension_manager.load_extension("jupyter_ai_magics")
    ai_magics = ip.magics_manager.registry["AiMagics"]
    ai_magics.transcript = [AI1, H1, AI2, H2, AI3]
    result = ip.run_line_magic("ai", "reset")
    # `reset` displays nothing and returns no value.
    assert result is None
    assert ai_magics.transcript == []

0 comments on commit b15ebb5

Please sign in to comment.