diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index 6a02c61c8..6a21eb763 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional, Tuple, Type, Union +from typing import Dict, Optional, Tuple, Type, Union, Literal, List from importlib_metadata import entry_points from jupyter_ai_magics.aliases import MODEL_ID_ALIASES @@ -13,11 +13,26 @@ ProviderDict = Dict[str, AnyProvider] -def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: +ProviderRestrictions = Dict[ + Literal["allowed_providers", "blocked_providers"], + Optional[List[str]] +] + + +def get_lm_providers( + log: Optional[Logger] = None, + restrictions: Optional[ProviderRestrictions] = None +) -> LmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) - + if not restrictions: + restrictions = { + "allowed_providers": None, + "blocked_providers": None + } + allowed = restrictions["allowed_providers"] + blocked = restrictions["blocked_providers"] providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.model_providers") @@ -29,6 +44,12 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: f"Unable to load model provider class from entry point `{model_provider_ep.name}`." ) continue + if blocked and provider.id in blocked: + log.info(f"Skipping provider not on block-list `{provider.id}`.") + continue + if allowed and provider.id not in allowed: + log.info(f"Skipping provider not on allow-list `{provider.id}`.") + continue providers[provider.id] = provider log.info(f"Registered model provider `{provider.id}`.") @@ -37,10 +58,18 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: def get_em_providers( log: Optional[Logger] = None, + restrictions: Optional[ProviderRestrictions] = None ) -> EmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) + if not restrictions: + restrictions = { + "allowed_providers": None, + "blocked_providers": None + } + allowed = restrictions["allowed_providers"] + blocked = restrictions["blocked_providers"] providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers") @@ -52,6 +81,12 @@ def get_em_providers( f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`." ) continue + if blocked and provider.id in blocked: + log.info(f"Skipping provider not on block-list `{provider.id}`.") + continue + if allowed and provider.id not in allowed: + log.info(f"Skipping provider not on allow-list `{provider.id}`.") + continue providers[provider.id] = provider log.info(f"Registered embeddings model provider `{provider.id}`.") diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 7b6d07b31..8a9679741 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -4,6 +4,7 @@ from jupyter_ai.chat_handlers.learn import Retriever from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp +from traitlets import List, Unicode from .chat_handlers import ( AskChatHandler, @@ -36,6 +37,22 @@ class AiExtension(ExtensionApp): (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), ] + allowed_providers = List( + Unicode(), + default_value=None, + help="Identifiers of allow-listed providers. If `None`, all are allowed.", + allow_none=True, + config=True, + ) + + blocked_providers = List( + Unicode(), + default_value=None, + help="Identifiers of block-listed providers. If `None`, none are blocked.", + allow_none=True, + config=True, + ) + def initialize_settings(self): start = time.time()