From 7ea7cee4ca634eb01651375c2e7190da26c4c521 Mon Sep 17 00:00:00 2001 From: Jason Weill <93281816+JasonWeill@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:13:24 -0800 Subject: [PATCH] Backport PR #560: Fixes lookup for custom chains --- .../jupyter_ai_magics/magics.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 6a4e445d0..7fb701b89 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -434,7 +434,11 @@ def _append_exchange_openai(self, prompt: str, output: str): def _decompose_model_id(self, model_id: str): """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" - if model_id in self.custom_model_registry: + # custom_model_registry maps keys to either a model name (a string) or an LLMChain. + # If this is an alias to another model, expand the full name of the model. + if model_id in self.custom_model_registry and isinstance( + self.custom_model_registry[model_id], str + ): model_id = self.custom_model_registry[model_id] return decompose_model_id(model_id, self.providers) @@ -508,6 +512,17 @@ def run_ai_cell(self, args: CellArgs, prompt: str): ) provider_id, local_model_id = self._decompose_model_id(args.model_id) + + # If this is a custom chain, send the message to the custom chain. + if args.model_id in self.custom_model_registry and isinstance( + self.custom_model_registry[args.model_id], LLMChain + ): + return self.display_output( + self.custom_model_registry[args.model_id].run(prompt), + args.format, + {"jupyter_ai": {"custom_chain_id": args.model_id}}, + ) + Provider = self._get_provider(provider_id) if Provider is None: return TextOrMarkdown(