forked from jupyterlab/jupyter-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow configuring a default model for cell magics (and line error magic)
- Loading branch information
1 parent
83e368b
commit c3d8ed4
Showing
3 changed files
with
78 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
41 changes: 40 additions & 1 deletion
41
packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,50 @@ | ||
from unittest.mock import patch | ||
|
||
from IPython import InteractiveShell | ||
from jupyter_ai_magics.magics import AiMagics | ||
from pytest import fixture | ||
from traitlets.config.loader import Config | ||
|
||
|
||
def test_aliases_config(): | ||
@fixture | ||
def ip() -> InteractiveShell: | ||
ip = InteractiveShell() | ||
ip.config = Config() | ||
return ip | ||
|
||
|
||
def test_aliases_config(ip): | ||
ip.config.AiMagics.aliases = {"my_custom_alias": "my_provider:my_model"} | ||
ip.extension_manager.load_extension("jupyter_ai_magics") | ||
providers_list = ip.run_line_magic("ai", "list").text | ||
assert "my_custom_alias" in providers_list | ||
|
||
|
||
def test_default_model_cell(ip): | ||
ip.config.AiMagics.default_language_model = "my-favourite-llm" | ||
ip.extension_manager.load_extension("jupyter_ai_magics") | ||
with patch.object(AiMagics, "run_ai_cell", return_value=None) as mock_run: | ||
ip.run_cell_magic("ai", "", cell="Write code for me please") | ||
assert mock_run.called | ||
cell_args = mock_run.call_args.args[0] | ||
assert cell_args.model_id == "my-favourite-llm" | ||
|
||
|
||
def test_non_default_model_cell(ip): | ||
ip.config.AiMagics.default_language_model = "my-favourite-llm" | ||
ip.extension_manager.load_extension("jupyter_ai_magics") | ||
with patch.object(AiMagics, "run_ai_cell", return_value=None) as mock_run: | ||
ip.run_cell_magic("ai", "some-different-llm", cell="Write code for me please") | ||
assert mock_run.called | ||
cell_args = mock_run.call_args.args[0] | ||
assert cell_args.model_id == "some-different-llm" | ||
|
||
|
||
def test_default_model_error_line(ip): | ||
ip.config.AiMagics.default_language_model = "my-favourite-llm" | ||
ip.extension_manager.load_extension("jupyter_ai_magics") | ||
with patch.object(AiMagics, "handle_error", return_value=None) as mock_run: | ||
ip.run_cell_magic("ai", "error", cell=None) | ||
assert mock_run.called | ||
cell_args = mock_run.call_args.args[0] | ||
assert cell_args.model_id == "my-favourite-llm" |