Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distinguish between completion and chat models #711

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
213891c
Distinguish between completion and chat models
krassowski Apr 5, 2024
2c196ef
Fix tests
krassowski Apr 3, 2024
578f63c
Shorten the tab name, move settings button
krassowski Apr 8, 2024
575f3f4
Implement the completion model selection in chat UI
krassowski Apr 17, 2024
9b90e4d
Improve docstring
krassowski Apr 17, 2024
9882134
Call `_validate_lm_em_id` only once, add typing annotations
krassowski Apr 17, 2024
ac7cd84
Remove embeddings provider for completions
krassowski Apr 17, 2024
59156f9
Use type alias to reduce changeset/make review easier
krassowski Apr 17, 2024
b06481a
Rename `_validate_lm_em_id` to `_validate_model_ids`
krassowski Apr 17, 2024
c410441
Merge branch 'main' into separate-completer-and-chat-settings
krassowski Apr 25, 2024
ec5b1b8
Rename `LLMHandlerMixin` to `CompletionsModelMixin`
krassowski May 2, 2024
84165e5
Rename "Chat LM" to "LM"; add title attribute; note
krassowski May 2, 2024
41ecf03
Rename heading "Completer model" → "Inline completions model"
krassowski May 2, 2024
8d727d3
Move `UseSignal` down to `CompleterSettingsButton` implementation
krassowski May 2, 2024
a9dc569
Rename the label in the select to "Inline completion model"
krassowski May 2, 2024
13d0ecf
Disable selection when completer is not enabled
krassowski May 2, 2024
92fa8c6
Remove use of `UseSignal`, tweak naming of `useState`
krassowski May 2, 2024
18e57f4
Use mui tooltips
krassowski May 2, 2024
5ac3cf3
Fix use of `jai_config_manager`
krassowski May 2, 2024
b5587fc
Fix tests
krassowski May 2, 2024
32b28b9
Merge branch 'main' into separate-completer-and-chat-settings
krassowski May 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,16 @@ class Config:
provider is selected.
"""

@classmethod
dlqqq marked this conversation as resolved.
Show resolved Hide resolved
def chat_models(cls):
    """Return the models of this provider which are suitable for chat.

    This is a classmethod, so the first argument is the class itself
    (named ``cls`` per PEP 8). The class-level ``models`` attribute is
    returned unchanged, i.e. every model is treated as chat-capable.
    """
    return cls.models

@classmethod
def completion_models(cls):
    """Return the models of this provider which are suitable for completions.

    The first argument is the class itself (named ``cls`` per PEP 8). The
    class-level ``models`` attribute is returned unchanged, i.e. every model
    is treated as completion-capable.
    """
    return cls.models

#
# instance attrs
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __init__(self, *args, **kwargs):
self.llm_chain = None

def get_llm_chain(self):
lm_provider = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params
lm_provider = self.config_manager.completions_lm_provider
lm_provider_params = self.config_manager.completions_lm_provider_params
dlqqq marked this conversation as resolved.
Show resolved Hide resolved
dlqqq marked this conversation as resolved.
Show resolved Hide resolved

if not lm_provider or not lm_provider_params:
return None
Expand Down
23 changes: 23 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config/config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
"default": null,
"readOnly": false
},
"completions_model_provider_id": {
"$comment": "Language model global ID for completions.",
"type": ["string", "null"],
"default": null,
"readOnly": false
},
"completions_embeddings_provider_id": {
"$comment": "Embedding model global ID for completions.",
"type": ["string", "null"],
"default": null,
"readOnly": false
},
krassowski marked this conversation as resolved.
Show resolved Hide resolved
"api_keys": {
"$comment": "Dictionary of API keys, mapping key names to key values.",
"type": "object",
Expand All @@ -37,6 +49,17 @@
}
},
"additionalProperties": false
},
"completions_fields": {
"$comment": "Dictionary of model-specific fields, mapping LM GIDs to sub-dictionaries of field key-value pairs for completions.",
"type": "object",
"default": {},
"patternProperties": {
"^.*$": {
"anyOf": [{ "type": "object" }]
}
},
"additionalProperties": false
}
}
}
124 changes: 67 additions & 57 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,41 +164,48 @@ def _process_existing_config(self, default_config):
{k: v for k, v in existing_config.items() if v is not None},
)
config = GlobalConfig(**merged_config)
validated_config = self._validate_lm_em_id(config)
validated_config = self._validate_lm_em_id(
config, lm_key="model_provider_id", em_key="embeddings_provider_id"
)
validated_config = self._validate_lm_em_id(
config,
lm_key="completions_model_provider_id",
em_key="completions_embeddings_provider_id",
)

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(validated_config)

def _validate_lm_em_id(self, config):
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id
def _validate_lm_em_id(self, config, lm_key, em_key):
lm_id = getattr(config, lm_key)
em_id = getattr(config, em_key)

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(lm_id, raise_exc=False):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
setattr(config, lm_key, None)
if em_id is not None and not self._validate_model(em_id, raise_exc=False):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
config.embeddings_provider_id = None
setattr(config, em_key, None)

# if the currently selected language or embedding model ids are
# not associated with models, set them to `None` and log a warning.
if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]:
self.log.warning(
f"No language model is associated with '{lm_id}'. Setting to None."
)
config.model_provider_id = None
setattr(config, lm_key, None)
if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]:
self.log.warning(
f"No embedding model is associated with '{em_id}'. Setting to None."
)
config.embeddings_provider_id = None
setattr(config, em_key, None)
dlqqq marked this conversation as resolved.
Show resolved Hide resolved

return config

Expand Down Expand Up @@ -321,28 +328,29 @@ def _write_config(self, new_config: GlobalConfig):
complete `GlobalConfig` object, and should not be called publicly."""
# remove any empty field dictionaries
new_config.fields = {k: v for k, v in new_config.fields.items() if v}
new_config.completions_fields = {
k: v for k, v in new_config.completions_fields.items() if v
}

self._validate_config(new_config)
with open(self.config_path, "w") as f:
json.dump(new_config.dict(), f, indent=self.indentation_depth)

def delete_api_key(self, key_name: str):
config_dict = self._read_config().dict()
lm_provider = self.lm_provider
em_provider = self.em_provider
required_keys = []
if (
lm_provider
and lm_provider.auth_strategy
and lm_provider.auth_strategy.type == "env"
):
required_keys.append(lm_provider.auth_strategy.name)
if (
em_provider
and em_provider.auth_strategy
and em_provider.auth_strategy.type == "env"
):
required_keys.append(self.em_provider.auth_strategy.name)
for provider in [
self.lm_provider,
self.em_provider,
self.completions_lm_provider,
self.completions_em_provider,
krassowski marked this conversation as resolved.
Show resolved Hide resolved
]:
if (
provider
and provider.auth_strategy
and provider.auth_strategy.type == "env"
):
required_keys.append(provider.auth_strategy.name)

if key_name in required_keys:
raise KeyInUseError(
Expand Down Expand Up @@ -390,67 +398,69 @@ def em_gid(self):

@property
def lm_provider(self):
config = self._read_config()
lm_gid = config.model_provider_id
if lm_gid is None:
return None

_, Provider = get_lm_provider(config.model_provider_id, self._lm_providers)
return Provider
return self._get_provider("model_provider_id", self._lm_providers)

@property
def em_provider(self):
    """Provider resolved for the ``embeddings_provider_id`` config key.

    Delegates to ``self._get_provider``, looking the configured ID up in
    ``self._em_providers``, and returns its result as-is.
    """
    return self._get_provider("embeddings_provider_id", self._em_providers)

@property
def completions_lm_provider(self):
    """Provider resolved for the ``completions_model_provider_id`` config key.

    Delegates to ``self._get_provider``, looking the configured ID up in
    ``self._lm_providers``, and returns its result as-is.
    """
    return self._get_provider("completions_model_provider_id", self._lm_providers)

@property
def completions_em_provider(self):
    """Provider resolved for the ``completions_embeddings_provider_id`` key.

    Delegates to ``self._get_provider``, looking the configured ID up in
    ``self._em_providers``, and returns its result as-is.
    """
    return self._get_provider(
        "completions_embeddings_provider_id", self._em_providers
    )

def _get_provider(self, key, listing):
config = self._read_config()
em_gid = config.embeddings_provider_id
if em_gid is None:
gid = getattr(config, key)
if gid is None:
return None

_, Provider = get_em_provider(em_gid, self._em_providers)
_, Provider = get_lm_provider(gid, listing)
return Provider

@property
def lm_provider_params(self):
# get generic fields
config = self._read_config()
lm_gid = config.model_provider_id
if not lm_gid:
return None
return self._provider_params("model_provider_id", self._lm_providers)

lm_lid = lm_gid.split(":", 1)[1]
fields = config.fields.get(lm_gid, {})

# get authn fields
_, Provider = get_lm_provider(lm_gid, self._lm_providers)
authn_fields = {}
if Provider.auth_strategy and Provider.auth_strategy.type == "env":
key_name = Provider.auth_strategy.name
authn_fields[key_name.lower()] = config.api_keys[key_name]
@property
def em_provider_params(self):
    """Parameters for the provider selected by ``embeddings_provider_id``.

    Delegates to ``self._provider_params`` with ``self._em_providers`` as the
    listing, and returns its result as-is.
    """
    return self._provider_params("embeddings_provider_id", self._em_providers)

return {
"model_id": lm_lid,
**fields,
**authn_fields,
}
@property
def completions_lm_provider_params(self):
    """Parameters for the provider selected by ``completions_model_provider_id``.

    Delegates to ``self._provider_params`` with ``self._lm_providers`` as the
    listing, and returns its result as-is.
    """
    return self._provider_params(
        "completions_model_provider_id", self._lm_providers
    )

@property
def em_provider_params(self):
def completions_em_provider_params(self):
return self._provider_params(
"completions_embeddings_provider_id", self._em_providers
)
krassowski marked this conversation as resolved.
Show resolved Hide resolved

def _provider_params(self, key, listing):
# get generic fields
config = self._read_config()
em_gid = config.embeddings_provider_id
if not em_gid:
gid = getattr(config, key)
if not gid:
return None

em_lid = em_gid.split(":", 1)[1]
lid = gid.split(":", 1)[1]

# get authn fields
_, Provider = get_em_provider(em_gid, self._em_providers)
_, Provider = get_em_provider(gid, listing)
authn_fields = {}
if Provider.auth_strategy and Provider.auth_strategy.type == "env":
key_name = Provider.auth_strategy.name
authn_fields[key_name.lower()] = config.api_keys[key_name]

return {
"model_id": em_lid,
"model_id": lid,
**authn_fields,
}

Expand Down
6 changes: 6 additions & 0 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ def filter_predicate(local_model_id: str):
# filter out every model w/ model ID according to allow/blocklist
for provider in providers:
provider.models = list(filter(filter_predicate, provider.models))
provider.chat_models = list(filter(filter_predicate, provider.chat_models))
provider.completion_models = list(
filter(filter_predicate, provider.completion_models)
)

# filter out every provider with no models which satisfy the allow/blocklist, then return
return filter((lambda p: len(p.models) > 0), providers)
Expand All @@ -311,6 +315,8 @@ def get(self):
id=provider.id,
name=provider.name,
models=provider.models,
chat_models=provider.chat_models(),
completion_models=provider.completion_models(),
help=provider.help,
auth_strategy=provider.auth_strategy,
registry=provider.registry,
Expand Down
11 changes: 11 additions & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class ListProvidersEntry(BaseModel):
auth_strategy: AuthStrategy
registry: bool
fields: List[Field]
chat_models: Optional[List[str]]
completion_models: Optional[List[str]]


class ListProvidersResponse(BaseModel):
Expand Down Expand Up @@ -121,6 +123,9 @@ class DescribeConfigResponse(BaseModel):
# timestamp indicating when the configuration file was last read. should be
# passed to the subsequent UpdateConfig request.
last_read: int
completions_model_provider_id: Optional[str]
completions_embeddings_provider_id: Optional[str]
completions_fields: Dict[str, Dict[str, Any]]


def forbid_none(cls, v):
Expand All @@ -137,6 +142,9 @@ class UpdateConfigRequest(BaseModel):
# if passed, this will raise an Error if the config was written to after the
# time specified by `last_read` to prevent write-write conflicts.
last_read: Optional[int]
completions_model_provider_id: Optional[str]
completions_embeddings_provider_id: Optional[str]
completions_fields: Optional[Dict[str, Dict[str, Any]]]

_validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)(
forbid_none
Expand All @@ -154,3 +162,6 @@ class GlobalConfig(BaseModel):
send_with_shift_enter: bool
fields: Dict[str, Dict[str, Any]]
api_keys: Dict[str, str]
completions_model_provider_id: Optional[str]
completions_embeddings_provider_id: Optional[str]
completions_fields: Dict[str, Dict[str, Any]]
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
dict({
'api_keys': list([
]),
'completions_embeddings_provider_id': None,
'completions_fields': dict({
}),
'completions_model_provider_id': None,
'embeddings_provider_id': None,
'fields': dict({
}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def __init__(self):
self.messages = []
self.tasks = []
self.settings["jai_config_manager"] = SimpleNamespace(
lm_provider=MockProvider, lm_provider_params={"model_id": "model"}
completions_lm_provider=MockProvider,
completions_lm_provider_params={"model_id": "model"},
)
self.settings["jai_event_loop"] = SimpleNamespace(
create_task=lambda x: self.tasks.append(x)
Expand Down
Loading
Loading