From 5bb4934ba38569aae291bfa7f03a62a1d61023aa Mon Sep 17 00:00:00 2001 From: Jason Weill <93281816+JasonWeill@users.noreply.github.com> Date: Thu, 8 Feb 2024 11:05:37 -0800 Subject: [PATCH] Unifies parameters to instantiate llm while incorporating model params (#632) * Unifies parameters to instantiate llm while incorporating model params * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py | 7 +++++-- packages/jupyter-ai/jupyter_ai/chat_handlers/default.py | 7 +++++-- packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py | 7 +++++-- .../jupyter-ai/jupyter_ai/completions/handlers/default.py | 7 +++++-- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index 8fe5c7f3a..c2d657e69 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -43,8 +43,11 @@ def __init__(self, retriever, *args, **kwargs): def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): - model_parameters = self.get_model_parameters(provider, provider_params) - self.llm = provider(**provider_params, **model_parameters) + unified_parameters = { + **provider_params, + **(self.get_model_parameters(provider, provider_params)), + } + self.llm = provider(**unified_parameters) memory = ConversationBufferWindowMemory( memory_key="chat_history", return_messages=True, k=2 ) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 584f0b33f..8352a8f8d 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -23,8 +23,11 @@ def __init__(self, *args, **kwargs): def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): - model_parameters = self.get_model_parameters(provider, provider_params) - llm = provider(**provider_params, **model_parameters) + unified_parameters = { + **provider_params, + **(self.get_model_parameters(provider, provider_params)), + } + llm = provider(**unified_parameters) prompt_template = llm.get_chat_prompt_template() self.memory = ConversationBufferWindowMemory( diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index 50ef6b02c..539a74159 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -236,8 +236,11 @@ def __init__(self, preferred_dir: str, log_dir: Optional[str], *args, **kwargs): def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): - model_parameters = self.get_model_parameters(provider, provider_params) - llm = provider(**provider_params, **model_parameters) + unified_parameters = { + **provider_params, + **(self.get_model_parameters(provider, provider_params)), + } + llm = provider(**unified_parameters) self.llm = llm return llm diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py index 552d23791..eb03df156 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py @@ -28,8 +28,11 @@ def __init__(self, *args, **kwargs): def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): - model_parameters = self.get_model_parameters(provider, provider_params) - llm = provider(**provider_params, **model_parameters) + unified_parameters = { + **provider_params, + **(self.get_model_parameters(provider, provider_params)), + } + llm = provider(**unified_parameters) prompt_template = llm.get_completion_prompt_template()