Skip to content

Commit

Permalink
Incorporated naming related comments
Browse files Browse the repository at this point in the history
  • Loading branch information
aws-khatria committed Jan 25, 2024
1 parent ff73fac commit 4d57c31
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 173 deletions.
6 changes: 3 additions & 3 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -770,9 +770,9 @@ a blocklist, to block some providers.
This configuration allows for setting a default model, and embedding provider with their corresponding API keys.

Following command line arguments can be used to initialize Jupyter AI extension with default providers:
1. ```--AiExtension.default_model_provider_id```: Specify the default LLM model. E.g.: ```--AiExtension.default_model_provider_id="bedrock-chat:anthropic.claude-v2"```
2. ```--AiExtension.default_embeddings_provider_id```: Specify default embedding provider. E.g.: ```--AiExtension.default_embeddings_provider_id="bedrock:amazon.titan-embed-text-v1"```
3. ```--AiExtension.default_api_keys```: Specify model provider keys in a JSON format. E.g. ```--AiExtension.default_api_keys='{"OPENAI_API_KEY": "sk-abcd}'```
1. ```--AiExtension.default_language_model```: Specify the default LLM model. Sample value: ```--AiExtension.default_language_model="bedrock-chat:anthropic.claude-v2"```
2. ```--AiExtension.default_embedding_model```: Specify default embedding model. Sample value: ```--AiExtension.default_embedding_model="bedrock:amazon.titan-embed-text-v1"```
3. ```--AiExtension.api_keys```: Specify model keys in a JSON format. Sample CLI argument: ```--AiExtension.api_keys='{"OPENAI_API_KEY": "sk-abcd}'```


### Blocklisting providers
Expand Down
11 changes: 6 additions & 5 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
blocked_providers: Optional[List[str]],
allowed_models: Optional[List[str]],
blocked_models: Optional[List[str]],
provider_defaults: dict,
defaults: dict,
*args,
**kwargs,
):
Expand All @@ -121,7 +121,7 @@ def __init__(
self._blocked_providers = blocked_providers
self._allowed_models = allowed_models
self._blocked_models = blocked_models
self._provider_defaults = provider_defaults
self._defaults = defaults
"""Provider defaults."""

self._last_read: Optional[int] = None
Expand Down Expand Up @@ -158,7 +158,8 @@ def _init_config(self):
def _process_existing_config(self, default_config):
with open(self.config_path, encoding="utf-8") as f:
existing_config = json.loads(f.read())
merged_config = {**existing_config, **default_config}
merged_config = Merger.merge(default_config,
{k: v for k, v in existing_config.items() if v is not None})
config = GlobalConfig(**merged_config)
validated_config = self._validate_lm_em_id(config)

Expand Down Expand Up @@ -207,11 +208,11 @@ def _init_defaults(self):
field_dict = {
field: properties.get(field).get("default") for field in field_list
}
if self._provider_defaults is None:
if self._defaults is None:
return field_dict

for field in field_list:
default_value = self._provider_defaults.get(field)
default_value = self._defaults.get(field)
if default_value is not None:
field_dict[field] = default_value
return field_dict
Expand Down
25 changes: 14 additions & 11 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,13 @@ class AiExtension(ExtensionApp):
config=True,
)

default_api_keys = Dict(
api_keys = Dict(
key_trait=Unicode(),
value_trait=Unicode(),
default_value=None,
allow_none=True,
help="API keys for language model providers.",
help="""API keys for model providers, as a dictionary, in the format
`<key-name>:<key-value>`. Defaults to None.""",
config=True,
)

Expand All @@ -115,17 +116,19 @@ class AiExtension(ExtensionApp):
config=True,
)

default_model_provider_id = Unicode(
default_language_model = Unicode(
default_value=None,
allow_none=True,
help="Default language model provider.",
help="""Default language model to use, as string in the format
<provider-id>:<model-id>, defaults to None.""",
config=True,
)

default_embeddings_provider_id = Unicode(
default_embedding_model = Unicode(
default_value=None,
allow_none=True,
help="Default embeddings model provider.",
help="""Default embedding model to use, as string in the format
<provider-id>:<model-id>, defaults to None.""",
config=True,
)

Expand All @@ -147,10 +150,10 @@ def initialize_settings(self):
self.settings["model_parameters"] = self.model_parameters
self.log.info(f"Configured model parameters: {self.model_parameters}")

provider_defaults = {
"model_provider_id": self.default_model_provider_id,
"embeddings_provider_id": self.default_embeddings_provider_id,
"api_keys": self.default_api_keys,
defaults = {
"model_provider_id": self.default_language_model,
"embeddings_provider_id": self.default_embedding_model,
"api_keys": self.api_keys,
"fields": self.model_parameters,
}

Expand All @@ -172,7 +175,7 @@ def initialize_settings(self):
blocked_providers=self.blocked_providers,
allowed_models=self.allowed_models,
blocked_models=self.blocked_models,
provider_defaults=provider_defaults,
defaults=defaults,
)

self.log.info("Registered providers.")
Expand Down
167 changes: 13 additions & 154 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def common_cm_kwargs(config_path, schema_path):
"allowed_models": None,
"blocked_models": None,
"restrictions": {"allowed_providers": None, "blocked_providers": None},
"provider_defaults": {
"defaults": {
"model_provider_id": None,
"embeddings_provider_id": None,
"api_keys": None,
Expand All @@ -52,19 +52,14 @@ def common_cm_kwargs(config_path, schema_path):


@pytest.fixture
def cm_kargs_with_defaults(config_path, schema_path):
def cm_kargs_with_defaults(config_path, schema_path, common_cm_kwargs):
"""Kwargs that are commonly used when initializing the CM."""
log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()
return {
"log": log,
"lm_providers": lm_providers,
"em_providers": em_providers,
"config_path": config_path,
"schema_path": schema_path,
"restrictions": {"allowed_providers": None, "blocked_providers": None},
"provider_defaults": {
**common_cm_kwargs,
"defaults": {
"model_provider_id": "bedrock-chat:anthropic.claude-v1",
"embeddings_provider_id": "bedrock:amazon.titan-embed-text-v1",
"api_keys": {"OPENAI_API_KEY": "open-ai-key-value"},
Expand Down Expand Up @@ -103,12 +98,7 @@ def cm_with_allowlists(common_cm_kwargs):
}
return ConfigManager(**kwargs)


def cm_with_defaults(cm_kargs_with_defaults):
"""The default ConfigManager instance, with an empty config and config schema."""
return ConfigManager(**cm_kargs_with_defaults)


@pytest.fixture
def cm_with_defaults(cm_kargs_with_defaults):
"""The default ConfigManager instance, with an empty config and config schema."""
return ConfigManager(**cm_kargs_with_defaults)
Expand Down Expand Up @@ -229,95 +219,7 @@ def test_init_with_allowlists(cm: ConfigManager, common_cm_kwargs):


def test_init_with_default_values(
cm_with_defaults: ConfigManager, config_path: str, schema_path: str
):
"""
Test that the ConfigManager initializes with the expected default values.
Args:
cm_with_defaults (ConfigManager): A ConfigManager instance with default values.
config_path (str): The path to the configuration file.
schema_path (str): The path to the schema file.
"""
config_response = cm_with_defaults.get_config()
# assert config response
assert config_response.model_provider_id == "bedrock-chat:anthropic.claude-v1"
assert (
config_response.embeddings_provider_id == "bedrock:amazon.titan-embed-text-v1"
)
assert config_response.api_keys == ["OPENAI_API_KEY"]
assert config_response.fields == {
"bedrock-chat:anthropic.claude-v1": {
"credentials_profile_name": "default",
"region_name": "us-west-2",
}
}

del cm_with_defaults

log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()
cm_with_defaults_override = ConfigManager(
log=log,
lm_providers=lm_providers,
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
provider_defaults={"model_provider_id": "bedrock-chat:anthropic.claude-v2"},
)


def test_init_with_default_values(
cm_with_defaults: ConfigManager, config_path: str, schema_path: str
):
"""
Test that the ConfigManager initializes with the expected default values.
Args:
cm_with_defaults (ConfigManager): A ConfigManager instance with default values.
config_path (str): The path to the configuration file.
schema_path (str): The path to the schema file.
"""
config_response = cm_with_defaults.get_config()
# assert config response
assert config_response.model_provider_id == "bedrock-chat:anthropic.claude-v1"
assert (
config_response.embeddings_provider_id == "bedrock:amazon.titan-embed-text-v1"
)
assert config_response.api_keys == ["OPENAI_API_KEY"]
assert config_response.fields == {
"bedrock-chat:anthropic.claude-v1": {
"credentials_profile_name": "default",
"region_name": "us-west-2",
}
}

del cm_with_defaults

log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()
cm_with_defaults_override = ConfigManager(
log=log,
lm_providers=lm_providers,
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
provider_defaults={"model_provider_id": "bedrock-chat:anthropic.claude-v2"},
)

assert (
cm_with_defaults_override.get_config().model_provider_id
== "bedrock-chat:anthropic.claude-v2"
)


def test_init_with_default_values(
cm_with_defaults: ConfigManager, config_path: str, schema_path: str
):
cm_with_defaults: ConfigManager, config_path: str, schema_path: str, common_cm_kwargs):
"""
Test that the ConfigManager initializes with the expected default values.
Expand Down Expand Up @@ -345,60 +247,17 @@ def test_init_with_default_values(
log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()
cm_with_defaults_override = ConfigManager(
log=log,
lm_providers=lm_providers,
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
provider_defaults={"model_provider_id": "bedrock-chat:anthropic.claude-v2"},
)


def test_init_with_default_values(
cm_with_defaults: ConfigManager, config_path: str, schema_path: str
):
"""
Test that the ConfigManager initializes with the expected default values.
Args:
cm_with_defaults (ConfigManager): A ConfigManager instance with default values.
config_path (str): The path to the configuration file.
schema_path (str): The path to the schema file.
"""
config_response = cm_with_defaults.get_config()
# assert config response
assert config_response.model_provider_id == "bedrock-chat:anthropic.claude-v1"
assert (
config_response.embeddings_provider_id == "bedrock:amazon.titan-embed-text-v1"
)
assert config_response.api_keys == ["OPENAI_API_KEY"]
assert config_response.fields == {
"bedrock-chat:anthropic.claude-v1": {
"credentials_profile_name": "default",
"region_name": "us-west-2",
}
kwargs = {
**common_cm_kwargs,
"defaults" : {
"model_provider_id": "bedrock-chat:anthropic.claude-v2"
},
}

del cm_with_defaults

log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()
cm_with_defaults_override = ConfigManager(
log=log,
lm_providers=lm_providers,
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
provider_defaults={"model_provider_id": "bedrock-chat:anthropic.claude-v2"},
)
cm_with_defaults_override = ConfigManager(**kwargs)

assert (
cm_with_defaults_override.get_config().model_provider_id
== "bedrock-chat:anthropic.claude-v2"
== "bedrock-chat:anthropic.claude-v1"
)


Expand Down
7 changes: 7 additions & 0 deletions packages/jupyter-ai/jupyter_ai/workspace.code-workspace
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"folders": [
{
"path": "../../.."
}
]
}

0 comments on commit 4d57c31

Please sign in to comment.