From 41be9c3da91d9c74530bc1f13b6f6c4a607ec24e Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Tue, 26 Nov 2024 16:11:09 -0800 Subject: [PATCH] fix lm_provider_params prop to include fields --- .../jupyter-ai/jupyter_ai/config_manager.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 2f38ca992..bc480335d 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -436,16 +436,21 @@ def completions_lm_provider_params(self): ) def _provider_params(self, key, listing): - # get generic fields + # read config config = self._read_config() - gid = getattr(config, key) - if not gid: + + # get model ID (without provider ID component) from model universal ID + # (with provider component). + model_uid = getattr(config, key) + if not model_uid: return None + model_id = model_uid.split(":", 1)[1] - lid = gid.split(":", 1)[1] + # get config fields (e.g. base API URL, etc.) + fields = config.fields.get(model_uid, {}) # get authn fields - _, Provider = get_em_provider(gid, listing) + _, Provider = get_em_provider(model_uid, listing) authn_fields = {} if Provider.auth_strategy and Provider.auth_strategy.type == "env": keyword_param = ( @@ -456,7 +461,8 @@ def _provider_params(self, key, listing): authn_fields[keyword_param] = config.api_keys[key_name] return { - "model_id": lid, + "model_id": model_id, + **fields, **authn_fields, }