Skip to content

Commit

Permalink
Distinguish between completion and chat models
Browse files Browse the repository at this point in the history
  • Loading branch information
krassowski committed Apr 5, 2024
1 parent 3bfce32 commit 2067552
Show file tree
Hide file tree
Showing 12 changed files with 706 additions and 280 deletions.
10 changes: 10 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,16 @@ class Config:
provider is selected.
"""

@classmethod
def chat_models(self):
"""Models which are suitable for chat."""
return self.models

@classmethod
def completion_models(self):
"""Models which are suitable for completions."""
return self.models

#
# instance attrs
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __init__(self, *args, **kwargs):
self.llm_chain = None

def get_llm_chain(self):
lm_provider = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params
lm_provider = self.config_manager.completions_lm_provider
lm_provider_params = self.config_manager.completions_lm_provider_params

if not lm_provider or not lm_provider_params:
return None
Expand Down
23 changes: 23 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config/config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
"default": null,
"readOnly": false
},
"completions_model_provider_id": {
"$comment": "Language model global ID for completions.",
"type": ["string", "null"],
"default": null,
"readOnly": false
},
"completions_embeddings_provider_id": {
"$comment": "Embedding model global ID for completions.",
"type": ["string", "null"],
"default": null,
"readOnly": false
},
"api_keys": {
"$comment": "Dictionary of API keys, mapping key names to key values.",
"type": "object",
Expand All @@ -37,6 +49,17 @@
}
},
"additionalProperties": false
},
"completions_fields": {
"$comment": "Dictionary of model-specific fields, mapping LM GIDs to sub-dictionaries of field key-value pairs for completions.",
"type": "object",
"default": {},
"patternProperties": {
"^.*$": {
"anyOf": [{ "type": "object" }]
}
},
"additionalProperties": false
}
}
}
124 changes: 67 additions & 57 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,41 +164,48 @@ def _process_existing_config(self, default_config):
{k: v for k, v in existing_config.items() if v is not None},
)
config = GlobalConfig(**merged_config)
validated_config = self._validate_lm_em_id(config)
validated_config = self._validate_lm_em_id(
config, lm_key="model_provider_id", em_key="embeddings_provider_id"
)
validated_config = self._validate_lm_em_id(
config,
lm_key="completions_model_provider_id",
em_key="completions_embeddings_provider_id",
)

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(validated_config)

def _validate_lm_em_id(self, config):
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id
def _validate_lm_em_id(self, config, lm_key, em_key):
lm_id = getattr(config, lm_key)
em_id = getattr(config, em_key)

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(lm_id, raise_exc=False):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
setattr(config, lm_key, None)
if em_id is not None and not self._validate_model(em_id, raise_exc=False):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
config.embeddings_provider_id = None
setattr(config, em_key, None)

# if the currently selected language or embedding model ids are
# not associated with models, set them to `None` and log a warning.
if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]:
self.log.warning(
f"No language model is associated with '{lm_id}'. Setting to None."
)
config.model_provider_id = None
setattr(config, lm_key, None)
if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]:
self.log.warning(
f"No embedding model is associated with '{em_id}'. Setting to None."
)
config.embeddings_provider_id = None
setattr(config, em_key, None)

return config

Expand Down Expand Up @@ -321,28 +328,29 @@ def _write_config(self, new_config: GlobalConfig):
complete `GlobalConfig` object, and should not be called publicly."""
# remove any empty field dictionaries
new_config.fields = {k: v for k, v in new_config.fields.items() if v}
new_config.completions_fields = {
k: v for k, v in new_config.completions_fields.items() if v
}

self._validate_config(new_config)
with open(self.config_path, "w") as f:
json.dump(new_config.dict(), f, indent=self.indentation_depth)

def delete_api_key(self, key_name: str):
config_dict = self._read_config().dict()
lm_provider = self.lm_provider
em_provider = self.em_provider
required_keys = []
if (
lm_provider
and lm_provider.auth_strategy
and lm_provider.auth_strategy.type == "env"
):
required_keys.append(lm_provider.auth_strategy.name)
if (
em_provider
and em_provider.auth_strategy
and em_provider.auth_strategy.type == "env"
):
required_keys.append(self.em_provider.auth_strategy.name)
for provider in [
self.lm_provider,
self.em_provider,
self.completions_lm_provider,
self.completions_em_provider,
]:
if (
provider
and provider.auth_strategy
and provider.auth_strategy.type == "env"
):
required_keys.append(provider.auth_strategy.name)

if key_name in required_keys:
raise KeyInUseError(
Expand Down Expand Up @@ -390,67 +398,69 @@ def em_gid(self):

@property
def lm_provider(self):
config = self._read_config()
lm_gid = config.model_provider_id
if lm_gid is None:
return None

_, Provider = get_lm_provider(config.model_provider_id, self._lm_providers)
return Provider
return self._get_provider("model_provider_id", self._lm_providers)

@property
def em_provider(self):
return self._get_provider("embeddings_provider_id", self._em_providers)

@property
def completions_lm_provider(self):
return self._get_provider("completions_model_provider_id", self._lm_providers)

@property
def completions_em_provider(self):
return self._get_provider(
"completions_embeddings_provider_id", self._em_providers
)

def _get_provider(self, key, listing):
config = self._read_config()
em_gid = config.embeddings_provider_id
if em_gid is None:
gid = getattr(config, key)
if gid is None:
return None

_, Provider = get_em_provider(em_gid, self._em_providers)
_, Provider = get_lm_provider(gid, listing)
return Provider

@property
def lm_provider_params(self):
# get generic fields
config = self._read_config()
lm_gid = config.model_provider_id
if not lm_gid:
return None
return self._provider_params("model_provider_id", self._lm_providers)

lm_lid = lm_gid.split(":", 1)[1]
fields = config.fields.get(lm_gid, {})

# get authn fields
_, Provider = get_lm_provider(lm_gid, self._lm_providers)
authn_fields = {}
if Provider.auth_strategy and Provider.auth_strategy.type == "env":
key_name = Provider.auth_strategy.name
authn_fields[key_name.lower()] = config.api_keys[key_name]
@property
def em_provider_params(self):
return self._provider_params("embeddings_provider_id", self._em_providers)

return {
"model_id": lm_lid,
**fields,
**authn_fields,
}
@property
def completions_lm_provider_params(self):
return self._provider_params(
"completions_model_provider_id", self._lm_providers
)

@property
def em_provider_params(self):
def completions_em_provider_params(self):
return self._provider_params(
"completions_embeddings_provider_id", self._em_providers
)

def _provider_params(self, key, listing):
# get generic fields
config = self._read_config()
em_gid = config.embeddings_provider_id
if not em_gid:
gid = getattr(config, key)
if not gid:
return None

em_lid = em_gid.split(":", 1)[1]
lid = gid.split(":", 1)[1]

# get authn fields
_, Provider = get_em_provider(em_gid, self._em_providers)
_, Provider = get_em_provider(gid, listing)
authn_fields = {}
if Provider.auth_strategy and Provider.auth_strategy.type == "env":
key_name = Provider.auth_strategy.name
authn_fields[key_name.lower()] = config.api_keys[key_name]

return {
"model_id": em_lid,
"model_id": lid,
**authn_fields,
}

Expand Down
6 changes: 6 additions & 0 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ def filter_predicate(local_model_id: str):
# filter out every model w/ model ID according to allow/blocklist
for provider in providers:
provider.models = list(filter(filter_predicate, provider.models))
provider.chat_models = list(filter(filter_predicate, provider.chat_models))
provider.completion_models = list(
filter(filter_predicate, provider.completion_models)
)

# filter out every provider with no models which satisfy the allow/blocklist, then return
return filter((lambda p: len(p.models) > 0), providers)
Expand All @@ -311,6 +315,8 @@ def get(self):
id=provider.id,
name=provider.name,
models=provider.models,
chat_models=provider.chat_models(),
completion_models=provider.completion_models(),
help=provider.help,
auth_strategy=provider.auth_strategy,
registry=provider.registry,
Expand Down
11 changes: 11 additions & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class ListProvidersEntry(BaseModel):
auth_strategy: AuthStrategy
registry: bool
fields: List[Field]
chat_models: Optional[List[str]]
completion_models: Optional[List[str]]


class ListProvidersResponse(BaseModel):
Expand Down Expand Up @@ -121,6 +123,9 @@ class DescribeConfigResponse(BaseModel):
# timestamp indicating when the configuration file was last read. should be
# passed to the subsequent UpdateConfig request.
last_read: int
completions_model_provider_id: Optional[str]
completions_embeddings_provider_id: Optional[str]
completions_fields: Dict[str, Dict[str, Any]]


def forbid_none(cls, v):
Expand All @@ -137,6 +142,9 @@ class UpdateConfigRequest(BaseModel):
# if passed, this will raise an Error if the config was written to after the
# time specified by `last_read` to prevent write-write conflicts.
last_read: Optional[int]
completions_model_provider_id: Optional[str]
completions_embeddings_provider_id: Optional[str]
completions_fields: Optional[Dict[str, Dict[str, Any]]]

_validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)(
forbid_none
Expand All @@ -154,3 +162,6 @@ class GlobalConfig(BaseModel):
send_with_shift_enter: bool
fields: Dict[str, Dict[str, Any]]
api_keys: Dict[str, str]
completions_model_provider_id: Optional[str]
completions_embeddings_provider_id: Optional[str]
completions_fields: Dict[str, Dict[str, Any]]
Loading

0 comments on commit 2067552

Please sign in to comment.