Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii-i committed Nov 11, 2023
1 parent a5a79d6 commit 76ff585
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
KeyInUseError,
WriteConflictError,
)
from jupyter_ai.models import DescribeConfigResponse, UpdateConfigRequest
from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from pydantic import ValidationError

Expand Down Expand Up @@ -85,8 +85,25 @@ def reset(config_path, schema_path):


@pytest.fixture
def config_with_bad_provider_ids():
return json.dumps({"model_provider_id:": "foo", "embeddings_provider_id": "bar"})
def config_with_bad_provider_ids(tmp_path):
"""Fixture that creates a `config.json` file with bad provider ids in `tmp_path` folder and returns path to the file."""
config_data = {
"model_provider_id:": "foo:bar",
"embeddings_provider_id": "buzz:fizz",
"api_keys": {},
"send_with_shift_enter": False,
"fields": {},
}
config_path = tmp_path / "config.json"
with open(config_path, "w") as file:
json.dump(config_data, file)
return str(config_path)


@pytest.fixture
def cm_with_bad_provider_ids(common_cm_kwargs, config_with_bad_provider_ids):
common_cm_kwargs["config_path"] = config_with_bad_provider_ids
return ConfigManager(**common_cm_kwargs)


def configure_to_cohere(cm: ConfigManager):
Expand Down Expand Up @@ -298,9 +315,7 @@ def test_forbid_deleting_key_in_use(cm: ConfigManager):
cm.delete_api_key("COHERE_API_KEY")


def test_handle_bad_provider_ids(config_with_bad_provider_ids, common_cm_kwargs):
with patch("builtins.open", mock_open(read_data=config_with_bad_provider_ids)):
cm = ConfigManager(**common_cm_kwargs)
config_desc = cm.get_config()
assert config_desc.model_provider_id is None
assert config_desc.embeddings_provider_id is None
def test_handle_bad_provider_ids(cm_with_bad_provider_ids):
config_desc = cm_with_bad_provider_ids.get_config()
assert config_desc.model_provider_id is None
assert config_desc.embeddings_provider_id is None

0 comments on commit 76ff585

Please sign in to comment.