Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
dlqqq committed Nov 27, 2024
1 parent 7b427df commit 27116fc
Showing 2 changed files with 17 additions and 16 deletions.
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
@@ -3,9 +3,9 @@
import os
import shutil
import time
from copy import deepcopy
from typing import List, Optional, Type, Union

from copy import deepcopy
from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest
31 changes: 16 additions & 15 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@ def config_path(jp_data_dir):
def schema_path(jp_data_dir):
return str(jp_data_dir / "config_schema.json")


@pytest.fixture
def config_file_with_model_fields(jp_data_dir):
"""
@@ -33,21 +34,16 @@ def config_file_with_model_fields(jp_data_dir):
config_data = {
"model_provider_id:": "openai-chat:gpt-4o",
"embeddings_provider_id": None,
"api_keys": {
"openai_api_key": "foobar"
},
"api_keys": {"openai_api_key": "foobar"},
"send_with_shift_enter": False,
"fields": {
"openai-chat:gpt-4o": {
"openai_api_base": "https://example.com"
}
},
"fields": {"openai-chat:gpt-4o": {"openai_api_base": "https://example.com"}},
}
config_path = jp_data_dir / "config.json"
with open(config_path, "w") as file:
json.dump(config_data, file)
return str(config_path)


@pytest.fixture
def common_cm_kwargs(config_path, schema_path):
"""Kwargs that are commonly used when initializing the CM."""
@@ -197,29 +193,29 @@ def configure_to_openai(cm: ConfigManager):
cm.update_config(req)
return LM_GID, EM_GID, LM_LID, EM_LID, API_PARAMS


def configure_with_fields(cm: ConfigManager):
"""
Configures the ConfigManager with fields and API keys.
Returns the expected result of `cm.lm_provider_params`.
"""
req = UpdateConfigRequest(
model_provider_id="openai-chat:gpt-4o",
api_keys={
"OPENAI_API_KEY": "foobar"
},
api_keys={"OPENAI_API_KEY": "foobar"},
fields={
"openai-chat:gpt-4o": {
"openai_api_base": "https://example.com",
}
}
},
)
cm.update_config(req)
return {
"model_id": "gpt-4o",
"openai_api_key": "foobar",
"openai_api_base": "https://example.com"
"openai_api_base": "https://example.com",
}


def test_snapshot_default_config(cm: ConfigManager, snapshot):
config_from_cm: DescribeConfigResponse = cm.get_config()
assert config_from_cm == snapshot(exclude=lambda prop, path: prop == "last_read")
@@ -448,6 +444,7 @@ def test_handle_bad_provider_ids(cm_with_bad_provider_ids):
assert config_desc.model_provider_id is None
assert config_desc.embeddings_provider_id is None


def test_config_manager_returns_fields(cm):
"""
Asserts that `ConfigManager.lm_provider_params` returns model fields set by
@@ -456,12 +453,16 @@ def test_config_manager_returns_fields(cm):
expected_model_args = configure_with_fields(cm)
assert cm.lm_provider_params == expected_model_args

def test_config_manager_does_not_write_to_defaults(config_file_with_model_fields, schema_path):

def test_config_manager_does_not_write_to_defaults(
config_file_with_model_fields, schema_path
):
"""
Asserts that `ConfigManager` does not write to the `defaults` argument when
the configured chat model differs from the one specified in `defaults`.
"""
from copy import deepcopy

config_path = config_file_with_model_fields
log = logging.getLogger()
lm_providers = get_lm_providers()
@@ -471,7 +472,7 @@ def test_config_manager_does_not_write_to_defaults(config_file_with_model_fields
"model_provider_id": None,
"embeddings_provider_id": None,
"api_keys": {},
"fields": {}
"fields": {},
}
expected_defaults = deepcopy(defaults)

0 comments on commit 27116fc

Please sign in to comment.