Skip to content

Commit

Permalink
Setting default model providers
Browse files Browse the repository at this point in the history
  • Loading branch information
aws-khatria committed Jan 23, 2024
1 parent 752e169 commit d198e11
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 2 deletions.
22 changes: 20 additions & 2 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,21 @@ def __init__(
blocked_providers: Optional[List[str]],
allowed_models: Optional[List[str]],
blocked_models: Optional[List[str]],
restrictions: ProviderRestrictions,
provider_defaults: dict,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.log = log
"""List of LM providers."""
self._lm_providers = lm_providers
"""List of EM providers."""
self._em_providers = em_providers
"""Provider restrictions."""
self._restrictions = restrictions
"""Provider defaults."""
self._provider_defaults = provider_defaults

self._lm_providers = lm_providers
"""List of LM providers."""
Expand Down Expand Up @@ -146,6 +156,7 @@ def _init_validator(self) -> Validator:
self.validator = Validator(schema)

def _init_config(self):
default_dict = self._init_defaults()
if os.path.exists(self.config_path):
self._process_existing_config()
else:
Expand Down Expand Up @@ -195,11 +206,18 @@ def _validate_lm_em_id(self, config):
def _create_default_config(self):
properties = self.validator.schema.get("properties", {})
field_list = GlobalConfig.__fields__.keys()
properties = self.validator.schema.get("properties", {})
field_dict = {
field: properties.get(field).get("default") for field in field_list
}
default_config = GlobalConfig(**field_dict)
self._write_config(default_config)
if self._provider_defaults is None:
return field_dict

for field in field_list:
default_value = self._provider_defaults.get(field)
if default_value is not None:
field_dict[field] = default_value
return field_dict

def _read_config(self) -> GlobalConfig:
"""Returns the user's current configuration as a GlobalConfig object.
Expand Down
40 changes: 40 additions & 0 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,36 @@ class AiExtension(ExtensionApp):
config=True,
)

model_provider_id = Unicode(
default_value=None,
allow_none=True,
help="Default language model provider.",
config=True,
)

embeddings_provider_id = Unicode(
default_value=None,
allow_none=True,
help="Default embeddings model provider.",
config=True,
)

api_keys = Dict(
Unicode(),
Unicode(),
default_value=None,
allow_none=True,
help="API keys for language model providers.",
config=True,
)

fields = Dict(
default_value=None,
allow_none=True,
help="Sub fields required for language model providers.",
config=True,
)

def initialize_settings(self):
start = time.time()

Expand All @@ -124,6 +154,14 @@ def initialize_settings(self):
self.settings["model_parameters"] = self.model_parameters
self.log.info(f"Configured model parameters: {self.model_parameters}")


provider_defaults = {
"model_provider_id": self.model_provider_id,
"embeddings_provider_id": self.embeddings_provider_id,
"api_keys": self.api_keys,
"fields": self.fields,
}

# Fetch LM & EM providers
self.settings["lm_providers"] = get_lm_providers(
log=self.log, restrictions=restrictions
Expand All @@ -142,6 +180,8 @@ def initialize_settings(self):
blocked_providers=self.blocked_providers,
allowed_models=self.allowed_models,
blocked_models=self.blocked_models,
restrictions=restrictions,
provider_defaults=provider_defaults,
)

self.log.info("Registered providers.")
Expand Down
67 changes: 67 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,35 @@ def common_cm_kwargs(config_path, schema_path):
"blocked_providers": None,
"allowed_models": None,
"blocked_models": None,
"restrictions": {"allowed_providers": None, "blocked_providers": None},
"provider_defaults": {
"model_provider_id": None,
"embeddings_provider_id": None,
"api_keys": None,
"fields": None,
}
}


@pytest.fixture
def cm_kargs_with_defaults(config_path, schema_path):
"""Kwargs that are commonly used when initializing the CM."""
log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()
return {
"log": log,
"lm_providers": lm_providers,
"em_providers": em_providers,
"config_path": config_path,
"schema_path": schema_path,
"restrictions": {"allowed_providers": None, "blocked_providers": None},
"provider_defaults": {
"model_provider_id": "bedrock-chat:anthropic.claude-v1",
"embeddings_provider_id": "bedrock:amazon.titan-embed-text-v1",
"api_keys": {"OPENAI_API_KEY": "open-ai-key-value"},
"fields": {"bedrock-chat:anthropic.claude-v1":{"credentials_profile_name": "default","region_name": "us-west-2"}},
}
}


Expand Down Expand Up @@ -69,6 +98,10 @@ def cm_with_allowlists(common_cm_kwargs):
}
return ConfigManager(**kwargs)

def cm_with_defaults(cm_kargs_with_defaults):
"""The default ConfigManager instance, with an empty config and config schema."""
return ConfigManager(**cm_kargs_with_defaults)


@pytest.fixture(autouse=True)
def reset(config_path, schema_path):
Expand Down Expand Up @@ -183,6 +216,40 @@ def test_init_with_allowlists(cm: ConfigManager, common_cm_kwargs):
assert test_cm.lm_gid == None
assert test_cm.em_gid == None

def test_init_with_default_values(
cm_with_defaults: ConfigManager, config_path: str, schema_path: str
):
"""
Test that the ConfigManager initializes with the expected default values.
Args:
cm_with_defaults (ConfigManager): A ConfigManager instance with default values.
config_path (str): The path to the configuration file.
schema_path (str): The path to the schema file.
"""
config_response = cm_with_defaults.get_config()
#assert config response
assert config_response.model_provider_id == "bedrock-chat:anthropic.claude-v1"
assert config_response.embeddings_provider_id == "bedrock:amazon.titan-embed-text-v1"
assert config_response.api_keys == ["OPENAI_API_KEY"]
assert config_response.fields == {"bedrock-chat:anthropic.claude-v1":{"credentials_profile_name": "default","region_name": "us-west-2"}}

del cm_with_defaults

log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()
cm_with_defaults_override =ConfigManager(
log=log,
lm_providers=lm_providers,
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
provider_defaults={"model_provider_id": "bedrock-chat:anthropic.claude-v2"},
)


def test_init_with_default_values(
cm_with_defaults: ConfigManager, config_path: str, schema_path: str
):
Expand Down

0 comments on commit d198e11

Please sign in to comment.