diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index e8fea6306..608bea77f 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -107,11 +107,11 @@ def __init__( log: Logger, lm_providers: LmProvidersDict, em_providers: EmProvidersDict, - allowed_providers: Optional[List[str]], - blocked_providers: Optional[List[str]], - allowed_models: Optional[List[str]], - blocked_models: Optional[List[str]], defaults: dict, + allowed_providers: Optional[List[str]] = None, + blocked_providers: Optional[List[str]] = None, + allowed_models: Optional[List[str]] = None, + blocked_models: Optional[List[str]] = None, *args, **kwargs, ): diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index 437333960..c807cf91b 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -24,6 +24,29 @@ def config_path(jp_data_dir): def schema_path(jp_data_dir): return str(jp_data_dir / "config_schema.json") +@pytest.fixture +def config_file_with_model_fields(jp_data_dir): + """ + Fixture that creates a `config.json` file with the chat model set to + `openai-chat:gpt-4o` and fields for that model. Returns path to the file. + """ + config_data = { + "model_provider_id:": "openai-chat:gpt-4o", + "embeddings_provider_id": None, + "api_keys": { + "openai_api_key": "foobar" + }, + "send_with_shift_enter": False, + "fields": { + "openai-chat:gpt-4o": { + "openai_api_base": "https://example.com" + } + }, + } + config_path = jp_data_dir / "config.json" + with open(config_path, "w") as file: + json.dump(config_data, file) + return str(config_path) @pytest.fixture def common_cm_kwargs(config_path, schema_path): @@ -426,5 +449,39 @@ def test_handle_bad_provider_ids(cm_with_bad_provider_ids): assert config_desc.embeddings_provider_id is None def test_config_manager_returns_fields(cm): + """ + Asserts that `ConfigManager.lm_provider_params` returns model fields set by + the user. + """ expected_model_args = configure_with_fields(cm) assert cm.lm_provider_params == expected_model_args + +def test_config_manager_does_not_write_to_defaults(config_file_with_model_fields, schema_path): + """ + Asserts that `ConfigManager` does not write to the `defaults` argument when + the configured chat model differs from the one specified in `defaults`. + """ + from copy import deepcopy + config_path = config_file_with_model_fields + log = logging.getLogger() + lm_providers = get_lm_providers() + em_providers = get_em_providers() + + defaults = { + "model_provider_id": None, + "embeddings_provider_id": None, + "api_keys": {}, + "fields": {} + } + expected_defaults = deepcopy(defaults) + + cm = ConfigManager( + log=log, + lm_providers=lm_providers, + em_providers=em_providers, + config_path=config_path, + schema_path=schema_path, + defaults=defaults, + ) + + assert defaults == expected_defaults