diff --git a/docs/source/users/index.md b/docs/source/users/index.md index 02714d3a7..766bee8bd 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -770,9 +770,9 @@ a blocklist, to block some providers. This configuration allows for setting a default model, and embedding provider with their corresponding API keys. Following command line arguments can be used to initialize Jupyter AI extension with default providers: -1. ```--AiExtension.default_model_provider_id```: Specify the default LLM model. E.g.: ```--AiExtension.default_model_provider_id="bedrock-chat:anthropic.claude-v2"``` -2. ```--AiExtension.default_embeddings_provider_id```: Specify default embedding provider. E.g.: ```--AiExtension.default_embeddings_provider_id="bedrock:amazon.titan-embed-text-v1"``` -3. ```--AiExtension.default_api_keys```: Specify model provider keys in a JSON format. E.g. ```--AiExtension.default_api_keys='{"OPENAI_API_KEY": "sk-abcd}'``` +1. ```--AiExtension.default_language_model```: Specify the default LLM model. Sample value: ```--AiExtension.default_language_model="bedrock-chat:anthropic.claude-v2"``` +2. ```--AiExtension.default_embedding_model```: Specify default embedding model. Sample value: ```--AiExtension.default_embedding_model="bedrock:amazon.titan-embed-text-v1"``` +3. ```--AiExtension.api_keys```: Specify model keys in a JSON format. 
Sample CLI argument: ```--AiExtension.api_keys='{"OPENAI_API_KEY": "sk-abcd"}'``` ### Blocklisting providers diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index b093dfd7e..a786419f8 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -105,7 +105,7 @@ def __init__( blocked_providers: Optional[List[str]], allowed_models: Optional[List[str]], blocked_models: Optional[List[str]], - provider_defaults: dict, + defaults: dict, *args, **kwargs, ): @@ -121,7 +121,7 @@ def __init__( self._blocked_providers = blocked_providers self._allowed_models = allowed_models self._blocked_models = blocked_models - self._provider_defaults = provider_defaults + self._defaults = defaults """Provider defaults.""" self._last_read: Optional[int] = None @@ -158,7 +158,8 @@ def _init_config(self): def _process_existing_config(self, default_config): with open(self.config_path, encoding="utf-8") as f: existing_config = json.loads(f.read()) - merged_config = {**existing_config, **default_config} + merged_config = Merger.merge(default_config, + {k: v for k, v in existing_config.items() if v is not None}) config = GlobalConfig(**merged_config) validated_config = self._validate_lm_em_id(config) @@ -207,11 +208,11 @@ def _init_defaults(self): field_dict = { field: properties.get(field).get("default") for field in field_list } - if self._provider_defaults is None: + if self._defaults is None: return field_dict for field in field_list: - default_value = self._provider_defaults.get(field) + default_value = self._defaults.get(field) if default_value is not None: field_dict[field] = default_value return field_dict diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 69313f835..5c5b3bf41 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -86,12 +86,13 @@ class 
AiExtension(ExtensionApp): config=True, ) - default_api_keys = Dict( + api_keys = Dict( key_trait=Unicode(), value_trait=Unicode(), default_value=None, allow_none=True, - help="API keys for language model providers.", + help="""API keys for model providers, as a dictionary, in the format + `:`. Defaults to None.""", config=True, ) @@ -115,17 +116,19 @@ class AiExtension(ExtensionApp): config=True, ) - default_model_provider_id = Unicode( + default_language_model = Unicode( default_value=None, allow_none=True, - help="Default language model provider.", + help="""Default language model to use, as string in the format + :, defaults to None.""", config=True, ) - default_embeddings_provider_id = Unicode( + default_embedding_model = Unicode( default_value=None, allow_none=True, - help="Default embeddings model provider.", + help="""Default embedding model to use, as string in the format + :, defaults to None.""", config=True, ) @@ -147,10 +150,10 @@ def initialize_settings(self): self.settings["model_parameters"] = self.model_parameters self.log.info(f"Configured model parameters: {self.model_parameters}") - provider_defaults = { - "model_provider_id": self.default_model_provider_id, - "embeddings_provider_id": self.default_embeddings_provider_id, - "api_keys": self.default_api_keys, + defaults = { + "model_provider_id": self.default_language_model, + "embeddings_provider_id": self.default_embedding_model, + "api_keys": self.api_keys, "fields": self.model_parameters, } @@ -172,7 +175,7 @@ def initialize_settings(self): blocked_providers=self.blocked_providers, allowed_models=self.allowed_models, blocked_models=self.blocked_models, - provider_defaults=provider_defaults, + defaults=defaults, ) self.log.info("Registered providers.") diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index ce2ae5786..b6587e741 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ 
b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -42,7 +42,7 @@ def common_cm_kwargs(config_path, schema_path): "allowed_models": None, "blocked_models": None, "restrictions": {"allowed_providers": None, "blocked_providers": None}, - "provider_defaults": { + "defaults": { "model_provider_id": None, "embeddings_provider_id": None, "api_keys": None, @@ -52,19 +52,14 @@ def common_cm_kwargs(config_path, schema_path): @pytest.fixture -def cm_kargs_with_defaults(config_path, schema_path): +def cm_kargs_with_defaults(config_path, schema_path, common_cm_kwargs): """Kwargs that are commonly used when initializing the CM.""" log = logging.getLogger() lm_providers = get_lm_providers() em_providers = get_em_providers() return { - "log": log, - "lm_providers": lm_providers, - "em_providers": em_providers, - "config_path": config_path, - "schema_path": schema_path, - "restrictions": {"allowed_providers": None, "blocked_providers": None}, - "provider_defaults": { + **common_cm_kwargs, + "defaults": { "model_provider_id": "bedrock-chat:anthropic.claude-v1", "embeddings_provider_id": "bedrock:amazon.titan-embed-text-v1", "api_keys": {"OPENAI_API_KEY": "open-ai-key-value"}, @@ -103,12 +98,7 @@ def cm_with_allowlists(common_cm_kwargs): } return ConfigManager(**kwargs) - -def cm_with_defaults(cm_kargs_with_defaults): - """The default ConfigManager instance, with an empty config and config schema.""" - return ConfigManager(**cm_kargs_with_defaults) - - +@pytest.fixture def cm_with_defaults(cm_kargs_with_defaults): """The default ConfigManager instance, with an empty config and config schema.""" return ConfigManager(**cm_kargs_with_defaults) @@ -229,95 +219,7 @@ def test_init_with_allowlists(cm: ConfigManager, common_cm_kwargs): def test_init_with_default_values( - cm_with_defaults: ConfigManager, config_path: str, schema_path: str -): - """ - Test that the ConfigManager initializes with the expected default values. 
- - Args: - cm_with_defaults (ConfigManager): A ConfigManager instance with default values. - config_path (str): The path to the configuration file. - schema_path (str): The path to the schema file. - """ - config_response = cm_with_defaults.get_config() - # assert config response - assert config_response.model_provider_id == "bedrock-chat:anthropic.claude-v1" - assert ( - config_response.embeddings_provider_id == "bedrock:amazon.titan-embed-text-v1" - ) - assert config_response.api_keys == ["OPENAI_API_KEY"] - assert config_response.fields == { - "bedrock-chat:anthropic.claude-v1": { - "credentials_profile_name": "default", - "region_name": "us-west-2", - } - } - - del cm_with_defaults - - log = logging.getLogger() - lm_providers = get_lm_providers() - em_providers = get_em_providers() - cm_with_defaults_override = ConfigManager( - log=log, - lm_providers=lm_providers, - em_providers=em_providers, - config_path=config_path, - schema_path=schema_path, - restrictions={"allowed_providers": None, "blocked_providers": None}, - provider_defaults={"model_provider_id": "bedrock-chat:anthropic.claude-v2"}, - ) - - -def test_init_with_default_values( - cm_with_defaults: ConfigManager, config_path: str, schema_path: str -): - """ - Test that the ConfigManager initializes with the expected default values. - - Args: - cm_with_defaults (ConfigManager): A ConfigManager instance with default values. - config_path (str): The path to the configuration file. - schema_path (str): The path to the schema file. 
- """ - config_response = cm_with_defaults.get_config() - # assert config response - assert config_response.model_provider_id == "bedrock-chat:anthropic.claude-v1" - assert ( - config_response.embeddings_provider_id == "bedrock:amazon.titan-embed-text-v1" - ) - assert config_response.api_keys == ["OPENAI_API_KEY"] - assert config_response.fields == { - "bedrock-chat:anthropic.claude-v1": { - "credentials_profile_name": "default", - "region_name": "us-west-2", - } - } - - del cm_with_defaults - - log = logging.getLogger() - lm_providers = get_lm_providers() - em_providers = get_em_providers() - cm_with_defaults_override = ConfigManager( - log=log, - lm_providers=lm_providers, - em_providers=em_providers, - config_path=config_path, - schema_path=schema_path, - restrictions={"allowed_providers": None, "blocked_providers": None}, - provider_defaults={"model_provider_id": "bedrock-chat:anthropic.claude-v2"}, - ) - - assert ( - cm_with_defaults_override.get_config().model_provider_id - == "bedrock-chat:anthropic.claude-v2" - ) - - -def test_init_with_default_values( - cm_with_defaults: ConfigManager, config_path: str, schema_path: str -): + cm_with_defaults: ConfigManager, config_path: str, schema_path: str, common_cm_kwargs): """ Test that the ConfigManager initializes with the expected default values. @@ -345,60 +247,17 @@ def test_init_with_default_values( log = logging.getLogger() lm_providers = get_lm_providers() em_providers = get_em_providers() - cm_with_defaults_override = ConfigManager( - log=log, - lm_providers=lm_providers, - em_providers=em_providers, - config_path=config_path, - schema_path=schema_path, - restrictions={"allowed_providers": None, "blocked_providers": None}, - provider_defaults={"model_provider_id": "bedrock-chat:anthropic.claude-v2"}, - ) - - -def test_init_with_default_values( - cm_with_defaults: ConfigManager, config_path: str, schema_path: str -): - """ - Test that the ConfigManager initializes with the expected default values. 
- - Args: - cm_with_defaults (ConfigManager): A ConfigManager instance with default values. - config_path (str): The path to the configuration file. - schema_path (str): The path to the schema file. - """ - config_response = cm_with_defaults.get_config() - # assert config response - assert config_response.model_provider_id == "bedrock-chat:anthropic.claude-v1" - assert ( - config_response.embeddings_provider_id == "bedrock:amazon.titan-embed-text-v1" - ) - assert config_response.api_keys == ["OPENAI_API_KEY"] - assert config_response.fields == { - "bedrock-chat:anthropic.claude-v1": { - "credentials_profile_name": "default", - "region_name": "us-west-2", - } + kwargs = { + **common_cm_kwargs, + "defaults" : { + "model_provider_id": "bedrock-chat:anthropic.claude-v2" + }, } - - del cm_with_defaults - - log = logging.getLogger() - lm_providers = get_lm_providers() - em_providers = get_em_providers() - cm_with_defaults_override = ConfigManager( - log=log, - lm_providers=lm_providers, - em_providers=em_providers, - config_path=config_path, - schema_path=schema_path, - restrictions={"allowed_providers": None, "blocked_providers": None}, - provider_defaults={"model_provider_id": "bedrock-chat:anthropic.claude-v2"}, - ) + cm_with_defaults_override = ConfigManager(**kwargs) assert ( cm_with_defaults_override.get_config().model_provider_id - == "bedrock-chat:anthropic.claude-v2" + == "bedrock-chat:anthropic.claude-v1" ) diff --git a/packages/jupyter-ai/jupyter_ai/workspace.code-workspace b/packages/jupyter-ai/jupyter_ai/workspace.code-workspace new file mode 100644 index 000000000..ec89ccf30 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/workspace.code-workspace @@ -0,0 +1,7 @@ +{ + "folders": [ + { + "path": "../../.." + } + ] +} \ No newline at end of file