diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 82ef03126..8d2812510 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -105,11 +105,21 @@ def __init__( blocked_providers: Optional[List[str]], allowed_models: Optional[List[str]], blocked_models: Optional[List[str]], + restrictions: ProviderRestrictions, + provider_defaults: dict, *args, **kwargs, ): super().__init__(*args, **kwargs) self.log = log + """List of LM providers.""" + self._lm_providers = lm_providers + """List of EM providers.""" + self._em_providers = em_providers + """Provider restrictions.""" + self._restrictions = restrictions + """Provider defaults.""" + self._provider_defaults = provider_defaults self._lm_providers = lm_providers """List of LM providers.""" @@ -146,6 +156,7 @@ def _init_validator(self) -> Validator: self.validator = Validator(schema) def _init_config(self): + default_dict = self._init_defaults() if os.path.exists(self.config_path): self._process_existing_config() else: @@ -195,11 +206,18 @@ def _validate_lm_em_id(self, config): def _create_default_config(self): properties = self.validator.schema.get("properties", {}) field_list = GlobalConfig.__fields__.keys() + properties = self.validator.schema.get("properties", {}) field_dict = { field: properties.get(field).get("default") for field in field_list } - default_config = GlobalConfig(**field_dict) - self._write_config(default_config) + if self._provider_defaults is None: + return field_dict + + for field in field_list: + default_value = self._provider_defaults.get(field) + if default_value is not None: + field_dict[field] = default_value + return field_dict def _read_config(self) -> GlobalConfig: """Returns the user's current configuration as a GlobalConfig object. diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index e3958fc7b..f4f4ecba3 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -106,6 +106,36 @@ class AiExtension(ExtensionApp): config=True, ) + model_provider_id = Unicode( + default_value=None, + allow_none=True, + help="Default language model provider.", + config=True, + ) + + embeddings_provider_id = Unicode( + default_value=None, + allow_none=True, + help="Default embeddings model provider.", + config=True, + ) + + api_keys = Dict( + Unicode(), + Unicode(), + default_value=None, + allow_none=True, + help="API keys for language model providers.", + config=True, + ) + + fields = Dict( + default_value=None, + allow_none=True, + help="Sub fields required for language model providers.", + config=True, + ) + def initialize_settings(self): start = time.time() @@ -124,6 +154,14 @@ def initialize_settings(self): self.settings["model_parameters"] = self.model_parameters self.log.info(f"Configured model parameters: {self.model_parameters}") + + provider_defaults = { + "model_provider_id": self.model_provider_id, + "embeddings_provider_id": self.embeddings_provider_id, + "api_keys": self.api_keys, + "fields": self.fields, + } + # Fetch LM & EM providers self.settings["lm_providers"] = get_lm_providers( log=self.log, restrictions=restrictions @@ -142,6 +180,8 @@ def initialize_settings(self): blocked_providers=self.blocked_providers, allowed_models=self.allowed_models, blocked_models=self.blocked_models, + restrictions=restrictions, + provider_defaults=provider_defaults, ) self.log.info("Registered providers.") diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index eeddfff88..7a0d3544f 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -41,6 +41,35 @@ def common_cm_kwargs(config_path, schema_path): "blocked_providers": None, "allowed_models": None, "blocked_models": None, + "restrictions": {"allowed_providers": None, "blocked_providers": None}, + "provider_defaults": { + "model_provider_id": None, + "embeddings_provider_id": None, + "api_keys": None, + "fields": None, + } + } + + +@pytest.fixture +def cm_kargs_with_defaults(config_path, schema_path): + """Kwargs that are commonly used when initializing the CM.""" + log = logging.getLogger() + lm_providers = get_lm_providers() + em_providers = get_em_providers() + return { + "log": log, + "lm_providers": lm_providers, + "em_providers": em_providers, + "config_path": config_path, + "schema_path": schema_path, + "restrictions": {"allowed_providers": None, "blocked_providers": None}, + "provider_defaults": { + "model_provider_id": "bedrock-chat:anthropic.claude-v1", + "embeddings_provider_id": "bedrock:amazon.titan-embed-text-v1", + "api_keys": {"OPENAI_API_KEY": "open-ai-key-value"}, + "fields": {"bedrock-chat:anthropic.claude-v1":{"credentials_profile_name": "default","region_name": "us-west-2"}}, + } } @@ -69,6 +98,10 @@ def cm_with_allowlists(common_cm_kwargs): } return ConfigManager(**kwargs) +def cm_with_defaults(cm_kargs_with_defaults): + """The default ConfigManager instance, with an empty config and config schema.""" + return ConfigManager(**cm_kargs_with_defaults) + @pytest.fixture(autouse=True) def reset(config_path, schema_path): @@ -183,6 +216,40 @@ def test_init_with_allowlists(cm: ConfigManager, common_cm_kwargs): assert test_cm.lm_gid == None assert test_cm.em_gid == None +def test_init_with_default_values( + cm_with_defaults: ConfigManager, config_path: str, schema_path: str +): + """ + Test that the ConfigManager initializes with the expected default values. + + Args: + cm_with_defaults (ConfigManager): A ConfigManager instance with default values. + config_path (str): The path to the configuration file. + schema_path (str): The path to the schema file. + """ + config_response = cm_with_defaults.get_config() + #assert config response + assert config_response.model_provider_id == "bedrock-chat:anthropic.claude-v1" + assert config_response.embeddings_provider_id == "bedrock:amazon.titan-embed-text-v1" + assert config_response.api_keys == ["OPENAI_API_KEY"] + assert config_response.fields == {"bedrock-chat:anthropic.claude-v1":{"credentials_profile_name": "default","region_name": "us-west-2"}} + + del cm_with_defaults + + log = logging.getLogger() + lm_providers = get_lm_providers() + em_providers = get_em_providers() + cm_with_defaults_override =ConfigManager( + log=log, + lm_providers=lm_providers, + em_providers=em_providers, + config_path=config_path, + schema_path=schema_path, + restrictions={"allowed_providers": None, "blocked_providers": None}, + provider_defaults={"model_provider_id": "bedrock-chat:anthropic.claude-v2"}, + ) + + def test_init_with_default_values( cm_with_defaults: ConfigManager, config_path: str, schema_path: str ):