From bb965c522fcfad078e30ea8ceb66cf7f66780e90 Mon Sep 17 00:00:00 2001
From: Jason Weill <93281816+JasonWeill@users.noreply.github.com>
Date: Thu, 8 Feb 2024 11:05:37 -0800
Subject: [PATCH] Unifies parameters to instantiate llm while incorporating
 model params (#632)

* Unifies parameters to instantiate llm while incorporating model params

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 17326b552e7fac6c6e6443b29ab7c8e32fbb3395)
---
 packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py      | 7 +++++--
 packages/jupyter-ai/jupyter_ai/chat_handlers/default.py  | 7 +++++--
 packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py | 7 +++++--
 3 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
index 8fe5c7f3a..c2d657e69 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -43,8 +43,11 @@ def __init__(self, retriever, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        model_parameters = self.get_model_parameters(provider, provider_params)
-        self.llm = provider(**provider_params, **model_parameters)
+        unified_parameters = {
+            **provider_params,
+            **(self.get_model_parameters(provider, provider_params)),
+        }
+        self.llm = provider(**unified_parameters)
         memory = ConversationBufferWindowMemory(
             memory_key="chat_history", return_messages=True, k=2
         )
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 584f0b33f..8352a8f8d 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -23,8 +23,11 @@ def __init__(self, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        model_parameters = self.get_model_parameters(provider, provider_params)
-        llm = provider(**provider_params, **model_parameters)
+        unified_parameters = {
+            **provider_params,
+            **(self.get_model_parameters(provider, provider_params)),
+        }
+        llm = provider(**unified_parameters)
         prompt_template = llm.get_chat_prompt_template()

         self.memory = ConversationBufferWindowMemory(
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
index 50ef6b02c..539a74159 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -236,8 +236,11 @@ def __init__(self, preferred_dir: str, log_dir: Optional[str], *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        model_parameters = self.get_model_parameters(provider, provider_params)
-        llm = provider(**provider_params, **model_parameters)
+        unified_parameters = {
+            **provider_params,
+            **(self.get_model_parameters(provider, provider_params)),
+        }
+        llm = provider(**unified_parameters)
         self.llm = llm
         return llm
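
Note on the change: all three handlers now build a single dict before
instantiating the provider, relying on standard Python semantics: when the
same key appears in multiple **-unpacked mappings inside a dict literal, the
last occurrence wins. A minimal, self-contained sketch of that precedence
(the parameter names and values below are hypothetical, not taken from the
patch):

    # Sketch of the merge semantics used by create_llm_chain(); values are
    # illustrative only. Later **-unpacked mappings win on duplicate keys,
    # so user-configured model parameters override the provider defaults.
    provider_params = {"model_id": "gpt-4", "temperature": 0.0}
    model_parameters = {"temperature": 0.7}  # e.g. from get_model_parameters()

    unified_parameters = {
        **provider_params,
        **model_parameters,
    }

    assert unified_parameters == {"model_id": "gpt-4", "temperature": 0.7}
    # The handler then instantiates the provider once:
    #     llm = provider(**unified_parameters)

A practical consequence of merging in a dict literal rather than in the call
itself: provider(**provider_params, **model_parameters) raises TypeError
("got multiple values for keyword argument ...") whenever the two mappings
share a key, while {**provider_params, **model_parameters} resolves the
collision in favor of model_parameters, which appears to be the point of
unifying the parameters here.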