From 741c5cf961d4c6b16b3a743d058109ca196a2162 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Krassowski?= <5832902+krassowski@users.noreply.github.com> Date: Mon, 23 Oct 2023 20:42:11 +0100 Subject: [PATCH] Allow to define block and allow lists for providers (#415) * Allow to block or allow-list providers by id * Add tests for block/allow-lists * Fix "No language model is associated with" issue This was appearing because the models which are blocked were not returned (correctly!) but the previous validation logic did not know that sometimes models may be missing for a valid reason even if there are existing settings for these. * Add docs for allow listing and block listing providers * Updated docs * Added an intro block to docs * Updated the docs --------- Co-authored-by: Piyush Jain --- docs/source/users/index.md | 31 +++++++++++++++ .../jupyter_ai_magics/tests/test_utils.py | 34 ++++++++++++++++ .../jupyter_ai_magics/utils.py | 32 +++++++++++++-- .../jupyter-ai/jupyter_ai/config_manager.py | 15 +++++++ packages/jupyter-ai/jupyter_ai/extension.py | 30 +++++++++++++- .../jupyter_ai/tests/test_config_manager.py | 2 + .../jupyter_ai/tests/test_extension.py | 39 +++++++++++++++++++ packages/jupyter-ai/pyproject.toml | 1 + 8 files changed, 178 insertions(+), 6 deletions(-) create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py create mode 100644 packages/jupyter-ai/jupyter_ai/tests/test_extension.py diff --git a/docs/source/users/index.md b/docs/source/users/index.md index 28f73193d..72b826a67 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -748,3 +748,34 @@ The `--region-name` parameter is set to the [AWS region code](https://docs.aws.a The `--request-schema` parameter is the JSON object the endpoint expects as input, with the prompt being substituted into any value that matches the string literal `""`. For example, the request schema `{"text_inputs":""}` will submit a JSON object with the prompt stored under the `text_inputs` key. The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonPath/index.html) string that retrieves the language model's output from the endpoint's JSON response. For example, if your endpoint returns an object with the schema `{"generated_texts":[""]}`, its response path is `generated_texts.[0]`. + + +## Configuration + +You can specify an allowlist, to only allow only a certain list of providers, or a blocklist, to block some providers. + +### Blocklisting providers +This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section. + +``` +jupyter lab --AiExtension.blocked_providers=openai +``` + +To block more than one provider in the block-list, repeat the runtime configuration. + +``` +jupyter lab --AiExtension.blocked_providers=openai --AiExtension.blocked_providers=ai21 +``` + +### Allowlisting providers +This configuration allows for filtering the list of providers in the settings panel to only an allowlisted set of providers. + +``` +jupyter lab --AiExtension.allowed_providers=openai +``` + +To allow more than one provider in the allowlist, repeat the runtime configuration. + +``` +jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21 +``` diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py new file mode 100644 index 000000000..e1c517ebe --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py @@ -0,0 +1,34 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import pytest +from jupyter_ai_magics.utils import get_lm_providers + +KNOWN_LM_A = "openai" +KNOWN_LM_B = "huggingface_hub" + + +@pytest.mark.parametrize( + "restrictions", + [ + {"allowed_providers": None, "blocked_providers": None}, + {"allowed_providers": [], "blocked_providers": []}, + {"allowed_providers": [], "blocked_providers": [KNOWN_LM_B]}, + {"allowed_providers": [KNOWN_LM_A], "blocked_providers": []}, + ], +) +def test_get_lm_providers_not_restricted(restrictions): + a_not_restricted = get_lm_providers(None, restrictions) + assert KNOWN_LM_A in a_not_restricted + + +@pytest.mark.parametrize( + "restrictions", + [ + {"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]}, + {"allowed_providers": [KNOWN_LM_B], "blocked_providers": []}, + ], +) +def test_get_lm_providers_restricted(restrictions): + a_not_restricted = get_lm_providers(None, restrictions) + assert KNOWN_LM_A not in a_not_restricted diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index 6a02c61c8..c651581bc 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional, Tuple, Type, Union +from typing import Dict, List, Literal, Optional, Tuple, Type, Union from importlib_metadata import entry_points from jupyter_ai_magics.aliases import MODEL_ID_ALIASES @@ -11,13 +11,19 @@ EmProvidersDict = Dict[str, BaseEmbeddingsProvider] AnyProvider = Union[BaseProvider, BaseEmbeddingsProvider] ProviderDict = Dict[str, AnyProvider] +ProviderRestrictions = Dict[ + Literal["allowed_providers", "blocked_providers"], Optional[List[str]] +] -def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: +def get_lm_providers( + log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None +) -> LmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) - + if not restrictions: + restrictions = {"allowed_providers": None, "blocked_providers": None} providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.model_providers") @@ -29,6 +35,9 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: f"Unable to load model provider class from entry point `{model_provider_ep.name}`." ) continue + if not is_provider_allowed(provider.id, restrictions): + log.info(f"Skipping blocked provider `{provider.id}`.") + continue providers[provider.id] = provider log.info(f"Registered model provider `{provider.id}`.") @@ -36,11 +45,13 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: def get_em_providers( - log: Optional[Logger] = None, + log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None ) -> EmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) + if not restrictions: + restrictions = {"allowed_providers": None, "blocked_providers": None} providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers") @@ -52,6 +63,9 @@ def get_em_providers( f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`." ) continue + if not is_provider_allowed(provider.id, restrictions): + log.info(f"Skipping blocked provider `{provider.id}`.") + continue providers[provider.id] = provider log.info(f"Registered embeddings model provider `{provider.id}`.") @@ -97,6 +111,16 @@ def get_em_provider( return _get_provider(model_id, em_providers) +def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool: + allowed = restrictions["allowed_providers"] + blocked = restrictions["blocked_providers"] + if blocked and provider_id in blocked: + return False + if allowed and provider_id not in allowed: + return False + return True + + def _get_provider(model_id: str, providers: ProviderDict): provider_id, local_model_id = decompose_model_id(model_id, providers) provider = providers.get(provider_id, None) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index f61e07bde..8708b4b94 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -12,8 +12,10 @@ AnyProvider, EmProvidersDict, LmProvidersDict, + ProviderRestrictions, get_em_provider, get_lm_provider, + is_provider_allowed, ) from jupyter_core.paths import jupyter_data_dir from traitlets import Integer, Unicode @@ -97,6 +99,7 @@ def __init__( log: Logger, lm_providers: LmProvidersDict, em_providers: EmProvidersDict, + restrictions: ProviderRestrictions, *args, **kwargs, ): @@ -106,6 +109,8 @@ def __init__( self._lm_providers = lm_providers """List of EM providers.""" self._em_providers = em_providers + """Provider restrictions.""" + self._restrictions = restrictions """When the server last read the config file. If the file was not modified after this time, then we can return the cached @@ -176,6 +181,10 @@ def _validate_config(self, config: GlobalConfig): _, lm_provider = get_lm_provider( config.model_provider_id, self._lm_providers ) + # do not check config for blocked providers + if not is_provider_allowed(config.model_provider_id, self._restrictions): + assert not lm_provider + return if not lm_provider: raise ValueError( f"No language model is associated with '{config.model_provider_id}'." @@ -187,6 +196,12 @@ def _validate_config(self, config: GlobalConfig): _, em_provider = get_em_provider( config.embeddings_provider_id, self._em_providers ) + # do not check config for blocked providers + if not is_provider_allowed( + config.embeddings_provider_id, self._restrictions + ): + assert not em_provider + return if not em_provider: raise ValueError( f"No embedding model is associated with '{config.embeddings_provider_id}'." diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 7b6d07b31..50865ed96 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -4,6 +4,7 @@ from jupyter_ai.chat_handlers.learn import Retriever from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp +from traitlets import List, Unicode from .chat_handlers import ( AskChatHandler, @@ -36,11 +37,35 @@ class AiExtension(ExtensionApp): (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), ] + allowed_providers = List( + Unicode(), + default_value=None, + help="Identifiers of allow-listed providers. If `None`, all are allowed.", + allow_none=True, + config=True, + ) + + blocked_providers = List( + Unicode(), + default_value=None, + help="Identifiers of block-listed providers. If `None`, none are blocked.", + allow_none=True, + config=True, + ) + def initialize_settings(self): start = time.time() + restrictions = { + "allowed_providers": self.allowed_providers, + "blocked_providers": self.blocked_providers, + } - self.settings["lm_providers"] = get_lm_providers(log=self.log) - self.settings["em_providers"] = get_em_providers(log=self.log) + self.settings["lm_providers"] = get_lm_providers( + log=self.log, restrictions=restrictions + ) + self.settings["em_providers"] = get_em_providers( + log=self.log, restrictions=restrictions + ) self.settings["jai_config_manager"] = ConfigManager( # traitlets configuration, not JAI configuration. @@ -48,6 +73,7 @@ def initialize_settings(self): log=self.log, lm_providers=self.settings["lm_providers"], em_providers=self.settings["em_providers"], + restrictions=restrictions, ) self.log.info("Registered providers.") diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index f7afbd29a..8cb1808fe 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -36,6 +36,7 @@ def common_cm_kwargs(config_path, schema_path): "em_providers": em_providers, "config_path": config_path, "schema_path": schema_path, + "restrictions": {"allowed_providers": None, "blocked_providers": None}, } @@ -112,6 +113,7 @@ def test_init_with_existing_config( em_providers=em_providers, config_path=config_path, schema_path=schema_path, + restrictions={"allowed_providers": None, "blocked_providers": None}, ) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py new file mode 100644 index 000000000..d1a10df77 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py @@ -0,0 +1,39 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. +import pytest +from jupyter_ai.extension import AiExtension + +pytest_plugins = ["pytest_jupyter.jupyter_server"] + +KNOWN_LM_A = "openai" +KNOWN_LM_B = "huggingface_hub" + + +@pytest.mark.parametrize( + "argv", + [ + ["--AiExtension.blocked_providers", KNOWN_LM_B], + ["--AiExtension.allowed_providers", KNOWN_LM_A], + ], +) +def test_allows_providers(argv, jp_configurable_serverapp): + server = jp_configurable_serverapp(argv=argv) + ai = AiExtension() + ai._link_jupyter_server_extension(server) + ai.initialize_settings() + assert KNOWN_LM_A in ai.settings["lm_providers"] + + +@pytest.mark.parametrize( + "argv", + [ + ["--AiExtension.blocked_providers", KNOWN_LM_A], + ["--AiExtension.allowed_providers", KNOWN_LM_B], + ], +) +def test_blocks_providers(argv, jp_configurable_serverapp): + server = jp_configurable_serverapp(argv=argv) + ai = AiExtension() + ai._link_jupyter_server_extension(server) + ai.initialize_settings() + assert KNOWN_LM_A not in ai.settings["lm_providers"] diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 2e56e0b07..99dd13ddd 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -51,6 +51,7 @@ test = [ "pytest-asyncio", "pytest-cov", "pytest-tornasync", + "pytest-jupyter", "syrupy~=4.0.8" ]