-
-
Notifications
You must be signed in to change notification settings - Fork 340
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow to define block and allow lists for providers (#415)
* Allow to block or allow-list providers by id * Add tests for block/allow-lists * Fix "No language model is associated with" issue This was appearing because the models which are blocked were not returned (correctly!) but the previous validation logic did not know that sometimes models may be missing for a valid reason even if there are existing settings for these. * Add docs for allow listing and block listing providers * Updated docs * Added an intro block to docs * Updated the docs --------- Co-authored-by: Piyush Jain <[email protected]>
- Loading branch information
1 parent
7c09863
commit 92dab10
Showing
8 changed files
with
178 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
34 changes: 34 additions & 0 deletions
34
packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright (c) Jupyter Development Team. | ||
# Distributed under the terms of the Modified BSD License. | ||
|
||
import pytest | ||
from jupyter_ai_magics.utils import get_lm_providers | ||
|
||
KNOWN_LM_A = "openai" | ||
KNOWN_LM_B = "huggingface_hub" | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"restrictions", | ||
[ | ||
{"allowed_providers": None, "blocked_providers": None}, | ||
{"allowed_providers": [], "blocked_providers": []}, | ||
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_B]}, | ||
{"allowed_providers": [KNOWN_LM_A], "blocked_providers": []}, | ||
], | ||
) | ||
def test_get_lm_providers_not_restricted(restrictions): | ||
a_not_restricted = get_lm_providers(None, restrictions) | ||
assert KNOWN_LM_A in a_not_restricted | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"restrictions", | ||
[ | ||
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]}, | ||
{"allowed_providers": [KNOWN_LM_B], "blocked_providers": []}, | ||
], | ||
) | ||
def test_get_lm_providers_restricted(restrictions): | ||
a_not_restricted = get_lm_providers(None, restrictions) | ||
assert KNOWN_LM_A not in a_not_restricted |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (c) Jupyter Development Team. | ||
# Distributed under the terms of the Modified BSD License. | ||
import pytest | ||
from jupyter_ai.extension import AiExtension | ||
|
||
pytest_plugins = ["pytest_jupyter.jupyter_server"] | ||
|
||
KNOWN_LM_A = "openai" | ||
KNOWN_LM_B = "huggingface_hub" | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"argv", | ||
[ | ||
["--AiExtension.blocked_providers", KNOWN_LM_B], | ||
["--AiExtension.allowed_providers", KNOWN_LM_A], | ||
], | ||
) | ||
def test_allows_providers(argv, jp_configurable_serverapp): | ||
server = jp_configurable_serverapp(argv=argv) | ||
ai = AiExtension() | ||
ai._link_jupyter_server_extension(server) | ||
ai.initialize_settings() | ||
assert KNOWN_LM_A in ai.settings["lm_providers"] | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"argv", | ||
[ | ||
["--AiExtension.blocked_providers", KNOWN_LM_A], | ||
["--AiExtension.allowed_providers", KNOWN_LM_B], | ||
], | ||
) | ||
def test_blocks_providers(argv, jp_configurable_serverapp): | ||
server = jp_configurable_serverapp(argv=argv) | ||
ai = AiExtension() | ||
ai._link_jupyter_server_extension(server) | ||
ai.initialize_settings() | ||
assert KNOWN_LM_A not in ai.settings["lm_providers"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,6 +51,7 @@ test = [ | |
"pytest-asyncio", | ||
"pytest-cov", | ||
"pytest-tornasync", | ||
"pytest-jupyter", | ||
"syrupy~=4.0.8" | ||
] | ||
|
||
|