Skip to content

Commit

Permalink
add test capturing bug introduced by #421
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Dec 2, 2024
1 parent 5f68261 commit f82785e
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 4 deletions.
8 changes: 4 additions & 4 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ def __init__(
log: Logger,
lm_providers: LmProvidersDict,
em_providers: EmProvidersDict,
allowed_providers: Optional[List[str]],
blocked_providers: Optional[List[str]],
allowed_models: Optional[List[str]],
blocked_models: Optional[List[str]],
defaults: dict,
allowed_providers: Optional[List[str]] = None,
blocked_providers: Optional[List[str]] = None,
allowed_models: Optional[List[str]] = None,
blocked_models: Optional[List[str]] = None,
*args,
**kwargs,
):
Expand Down
57 changes: 57 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,29 @@ def config_path(jp_data_dir):
def schema_path(jp_data_dir):
return str(jp_data_dir / "config_schema.json")

@pytest.fixture
def config_file_with_model_fields(jp_data_dir):
"""
Fixture that creates a `config.json` file with the chat model set to
`openai-chat:gpt-4o` and fields for that model. Returns path to the file.
"""
config_data = {
"model_provider_id:": "openai-chat:gpt-4o",
"embeddings_provider_id": None,
"api_keys": {
"openai_api_key": "foobar"
},
"send_with_shift_enter": False,
"fields": {
"openai-chat:gpt-4o": {
"openai_api_base": "https://example.com"
}
},
}
config_path = jp_data_dir / "config.json"
with open(config_path, "w") as file:
json.dump(config_data, file)
return str(config_path)

@pytest.fixture
def common_cm_kwargs(config_path, schema_path):
Expand Down Expand Up @@ -426,5 +449,39 @@ def test_handle_bad_provider_ids(cm_with_bad_provider_ids):
assert config_desc.embeddings_provider_id is None

def test_config_manager_returns_fields(cm):
"""
Asserts that `ConfigManager.lm_provider_params` returns model fields set by
the user.
"""
expected_model_args = configure_with_fields(cm)
assert cm.lm_provider_params == expected_model_args

def test_config_manager_does_not_write_to_defaults(config_file_with_model_fields, schema_path):
"""
Asserts that `ConfigManager` does not write to the `defaults` argument when
the configured chat model differs from the one specified in `defaults`.
"""
from copy import deepcopy
config_path = config_file_with_model_fields
log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()

defaults = {
"model_provider_id": None,
"embeddings_provider_id": None,
"api_keys": {},
"fields": {}
}
expected_defaults = deepcopy(defaults)

cm = ConfigManager(
log=log,
lm_providers=lm_providers,
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
defaults=defaults,
)

assert defaults == expected_defaults

0 comments on commit f82785e

Please sign in to comment.