From 76ff585caef8c56f873530523fe4f9c195ce5831 Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Fri, 10 Nov 2023 18:24:29 -0800 Subject: [PATCH] fix tests --- .../jupyter_ai/tests/test_config_manager.py | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index 63f6fffc8..0d109ca09 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -10,7 +10,7 @@ KeyInUseError, WriteConflictError, ) -from jupyter_ai.models import DescribeConfigResponse, UpdateConfigRequest +from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from pydantic import ValidationError @@ -85,8 +85,25 @@ def reset(config_path, schema_path): @pytest.fixture -def config_with_bad_provider_ids(): - return json.dumps({"model_provider_id:": "foo", "embeddings_provider_id": "bar"}) +def config_with_bad_provider_ids(tmp_path): + """Fixture that creates a `config.json` file with bad provider ids in `tmp_path` folder and returns path to the file.""" + config_data = { + "model_provider_id:": "foo:bar", + "embeddings_provider_id": "buzz:fizz", + "api_keys": {}, + "send_with_shift_enter": False, + "fields": {}, + } + config_path = tmp_path / "config.json" + with open(config_path, "w") as file: + json.dump(config_data, file) + return str(config_path) + + +@pytest.fixture +def cm_with_bad_provider_ids(common_cm_kwargs, config_with_bad_provider_ids): + common_cm_kwargs["config_path"] = config_with_bad_provider_ids + return ConfigManager(**common_cm_kwargs) def configure_to_cohere(cm: ConfigManager): @@ -298,9 +315,7 @@ def test_forbid_deleting_key_in_use(cm: ConfigManager): cm.delete_api_key("COHERE_API_KEY") -def test_handle_bad_provider_ids(config_with_bad_provider_ids, common_cm_kwargs): - with patch("builtins.open", mock_open(read_data=config_with_bad_provider_ids)): - cm = ConfigManager(**common_cm_kwargs) - config_desc = cm.get_config() - assert config_desc.model_provider_id is None - assert config_desc.embeddings_provider_id is None +def test_handle_bad_provider_ids(cm_with_bad_provider_ids): + config_desc = cm_with_bad_provider_ids.get_config() + assert config_desc.model_provider_id is None + assert config_desc.embeddings_provider_id is None