Skip to content

Commit

Permalink
add failing test that asserts fields are included in lm_provider_params
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Dec 2, 2024
1 parent 922712c commit d79c85c
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,28 @@ def configure_to_openai(cm: ConfigManager):
cm.update_config(req)
return LM_GID, EM_GID, LM_LID, EM_LID, API_PARAMS

def configure_with_fields(cm: ConfigManager):
"""
Configures the ConfigManager with fields and API keys.
Returns the expected result of `cm.lm_provider_params`.
"""
req = UpdateConfigRequest(
model_provider_id="openai-chat:gpt-4o",
api_keys={
"OPENAI_API_KEY": "foobar"
},
fields={
"openai-chat:gpt-4o": {
"openai_api_base": "https://example.com",
}
}
)
cm.update_config(req)
return {
"model_id": "gpt-4o",
"openai_api_key": "foobar",
"openai_api_base": "https://example.com"
}

def test_snapshot_default_config(cm: ConfigManager, snapshot):
config_from_cm: DescribeConfigResponse = cm.get_config()
Expand Down Expand Up @@ -402,3 +424,7 @@ def test_handle_bad_provider_ids(cm_with_bad_provider_ids):
config_desc = cm_with_bad_provider_ids.get_config()
assert config_desc.model_provider_id is None
assert config_desc.embeddings_provider_id is None

def test_config_manager_returns_fields(cm):
expected_model_args = configure_with_fields(cm)
assert cm.lm_provider_params == expected_model_args

0 comments on commit d79c85c

Please sign in to comment.