Skip to content

Commit

Permalink
Allow to define block and allow lists for providers (#415)
Browse files Browse the repository at this point in the history
* Allow to block or allow-list providers by id

* Add tests for block/allow-lists

* Fix "No language model is associated with" issue

This was appearing because the models which are blocked were not
returned (correctly!) but the previous validation logic did not
know that sometimes models may be missing for a valid reason even
if there are existing settings for these.

* Add docs for allow listing and block listing providers

* Updated docs

* Added an intro block to docs

* Updated the docs

---------

Co-authored-by: Piyush Jain <[email protected]>
  • Loading branch information
krassowski and 3coins authored Oct 23, 2023
1 parent 7c09863 commit 92dab10
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 6 deletions.
31 changes: 31 additions & 0 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -748,3 +748,34 @@ The `--region-name` parameter is set to the [AWS region code](https://docs.aws.a
The `--request-schema` parameter is the JSON object the endpoint expects as input, with the prompt being substituted into any value that matches the string literal `"<prompt>"`. For example, the request schema `{"text_inputs":"<prompt>"}` will submit a JSON object with the prompt stored under the `text_inputs` key.

The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonPath/index.html) string that retrieves the language model's output from the endpoint's JSON response. For example, if your endpoint returns an object with the schema `{"generated_texts":["<output>"]}`, its response path is `generated_texts.[0]`.


## Configuration

You can specify an allowlist, to only allow only a certain list of providers, or a blocklist, to block some providers.

### Blocklisting providers
This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section.

```
jupyter lab --AiExtension.blocked_providers=openai
```

To block more than one provider in the block-list, repeat the runtime configuration.

```
jupyter lab --AiExtension.blocked_providers=openai --AiExtension.blocked_providers=ai21
```

### Allowlisting providers
This configuration allows for filtering the list of providers in the settings panel to only an allowlisted set of providers.

```
jupyter lab --AiExtension.allowed_providers=openai
```

To allow more than one provider in the allowlist, repeat the runtime configuration.

```
jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21
```
34 changes: 34 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import pytest
from jupyter_ai_magics.utils import get_lm_providers

KNOWN_LM_A = "openai"
KNOWN_LM_B = "huggingface_hub"


@pytest.mark.parametrize(
"restrictions",
[
{"allowed_providers": None, "blocked_providers": None},
{"allowed_providers": [], "blocked_providers": []},
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_B]},
{"allowed_providers": [KNOWN_LM_A], "blocked_providers": []},
],
)
def test_get_lm_providers_not_restricted(restrictions):
a_not_restricted = get_lm_providers(None, restrictions)
assert KNOWN_LM_A in a_not_restricted


@pytest.mark.parametrize(
"restrictions",
[
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]},
{"allowed_providers": [KNOWN_LM_B], "blocked_providers": []},
],
)
def test_get_lm_providers_restricted(restrictions):
a_not_restricted = get_lm_providers(None, restrictions)
assert KNOWN_LM_A not in a_not_restricted
32 changes: 28 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Optional, Tuple, Type, Union
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

from importlib_metadata import entry_points
from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
Expand All @@ -11,13 +11,19 @@
EmProvidersDict = Dict[str, BaseEmbeddingsProvider]
AnyProvider = Union[BaseProvider, BaseEmbeddingsProvider]
ProviderDict = Dict[str, AnyProvider]
ProviderRestrictions = Dict[
Literal["allowed_providers", "blocked_providers"], Optional[List[str]]
]


def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict:
def get_lm_providers(
log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None
) -> LmProvidersDict:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())

if not restrictions:
restrictions = {"allowed_providers": None, "blocked_providers": None}
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.model_providers")
Expand All @@ -29,18 +35,23 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict:
f"Unable to load model provider class from entry point `{model_provider_ep.name}`."
)
continue
if not is_provider_allowed(provider.id, restrictions):
log.info(f"Skipping blocked provider `{provider.id}`.")
continue
providers[provider.id] = provider
log.info(f"Registered model provider `{provider.id}`.")

return providers


def get_em_providers(
log: Optional[Logger] = None,
log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None
) -> EmProvidersDict:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())
if not restrictions:
restrictions = {"allowed_providers": None, "blocked_providers": None}
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers")
Expand All @@ -52,6 +63,9 @@ def get_em_providers(
f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`."
)
continue
if not is_provider_allowed(provider.id, restrictions):
log.info(f"Skipping blocked provider `{provider.id}`.")
continue
providers[provider.id] = provider
log.info(f"Registered embeddings model provider `{provider.id}`.")

Expand Down Expand Up @@ -97,6 +111,16 @@ def get_em_provider(
return _get_provider(model_id, em_providers)


def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool:
allowed = restrictions["allowed_providers"]
blocked = restrictions["blocked_providers"]
if blocked and provider_id in blocked:
return False
if allowed and provider_id not in allowed:
return False
return True


def _get_provider(model_id: str, providers: ProviderDict):
provider_id, local_model_id = decompose_model_id(model_id, providers)
provider = providers.get(provider_id, None)
Expand Down
15 changes: 15 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
AnyProvider,
EmProvidersDict,
LmProvidersDict,
ProviderRestrictions,
get_em_provider,
get_lm_provider,
is_provider_allowed,
)
from jupyter_core.paths import jupyter_data_dir
from traitlets import Integer, Unicode
Expand Down Expand Up @@ -97,6 +99,7 @@ def __init__(
log: Logger,
lm_providers: LmProvidersDict,
em_providers: EmProvidersDict,
restrictions: ProviderRestrictions,
*args,
**kwargs,
):
Expand All @@ -106,6 +109,8 @@ def __init__(
self._lm_providers = lm_providers
"""List of EM providers."""
self._em_providers = em_providers
"""Provider restrictions."""
self._restrictions = restrictions

"""When the server last read the config file. If the file was not
modified after this time, then we can return the cached
Expand Down Expand Up @@ -176,6 +181,10 @@ def _validate_config(self, config: GlobalConfig):
_, lm_provider = get_lm_provider(
config.model_provider_id, self._lm_providers
)
# do not check config for blocked providers
if not is_provider_allowed(config.model_provider_id, self._restrictions):
assert not lm_provider
return
if not lm_provider:
raise ValueError(
f"No language model is associated with '{config.model_provider_id}'."
Expand All @@ -187,6 +196,12 @@ def _validate_config(self, config: GlobalConfig):
_, em_provider = get_em_provider(
config.embeddings_provider_id, self._em_providers
)
# do not check config for blocked providers
if not is_provider_allowed(
config.embeddings_provider_id, self._restrictions
):
assert not em_provider
return
if not em_provider:
raise ValueError(
f"No embedding model is associated with '{config.embeddings_provider_id}'."
Expand Down
30 changes: 28 additions & 2 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp
from traitlets import List, Unicode

from .chat_handlers import (
AskChatHandler,
Expand Down Expand Up @@ -36,18 +37,43 @@ class AiExtension(ExtensionApp):
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
]

allowed_providers = List(
Unicode(),
default_value=None,
help="Identifiers of allow-listed providers. If `None`, all are allowed.",
allow_none=True,
config=True,
)

blocked_providers = List(
Unicode(),
default_value=None,
help="Identifiers of block-listed providers. If `None`, none are blocked.",
allow_none=True,
config=True,
)

def initialize_settings(self):
start = time.time()
restrictions = {
"allowed_providers": self.allowed_providers,
"blocked_providers": self.blocked_providers,
}

self.settings["lm_providers"] = get_lm_providers(log=self.log)
self.settings["em_providers"] = get_em_providers(log=self.log)
self.settings["lm_providers"] = get_lm_providers(
log=self.log, restrictions=restrictions
)
self.settings["em_providers"] = get_em_providers(
log=self.log, restrictions=restrictions
)

self.settings["jai_config_manager"] = ConfigManager(
# traitlets configuration, not JAI configuration.
config=self.config,
log=self.log,
lm_providers=self.settings["lm_providers"],
em_providers=self.settings["em_providers"],
restrictions=restrictions,
)

self.log.info("Registered providers.")
Expand Down
2 changes: 2 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def common_cm_kwargs(config_path, schema_path):
"em_providers": em_providers,
"config_path": config_path,
"schema_path": schema_path,
"restrictions": {"allowed_providers": None, "blocked_providers": None},
}


Expand Down Expand Up @@ -112,6 +113,7 @@ def test_init_with_existing_config(
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
)


Expand Down
39 changes: 39 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import pytest
from jupyter_ai.extension import AiExtension

pytest_plugins = ["pytest_jupyter.jupyter_server"]

KNOWN_LM_A = "openai"
KNOWN_LM_B = "huggingface_hub"


@pytest.mark.parametrize(
"argv",
[
["--AiExtension.blocked_providers", KNOWN_LM_B],
["--AiExtension.allowed_providers", KNOWN_LM_A],
],
)
def test_allows_providers(argv, jp_configurable_serverapp):
server = jp_configurable_serverapp(argv=argv)
ai = AiExtension()
ai._link_jupyter_server_extension(server)
ai.initialize_settings()
assert KNOWN_LM_A in ai.settings["lm_providers"]


@pytest.mark.parametrize(
"argv",
[
["--AiExtension.blocked_providers", KNOWN_LM_A],
["--AiExtension.allowed_providers", KNOWN_LM_B],
],
)
def test_blocks_providers(argv, jp_configurable_serverapp):
server = jp_configurable_serverapp(argv=argv)
ai = AiExtension()
ai._link_jupyter_server_extension(server)
ai.initialize_settings()
assert KNOWN_LM_A not in ai.settings["lm_providers"]
1 change: 1 addition & 0 deletions packages/jupyter-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ test = [
"pytest-asyncio",
"pytest-cov",
"pytest-tornasync",
"pytest-jupyter",
"syrupy~=4.0.8"
]

Expand Down

0 comments on commit 92dab10

Please sign in to comment.