From 166309af34958c84205b92d48479d79db19d6520 Mon Sep 17 00:00:00 2001 From: Jason Weill Date: Thu, 8 Feb 2024 09:58:51 -0800 Subject: [PATCH 1/2] Unifies parameters to instantiate llm while incorporating model params --- packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py | 7 +++++-- packages/jupyter-ai/jupyter_ai/chat_handlers/default.py | 7 +++++-- packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py | 7 +++++-- .../jupyter-ai/jupyter_ai/completions/handlers/default.py | 7 +++++-- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index 8fe5c7f3a..128a31326 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -43,8 +43,11 @@ def __init__(self, retriever, *args, **kwargs): def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): - model_parameters = self.get_model_parameters(provider, provider_params) - self.llm = provider(**provider_params, **model_parameters) + unified_parameters = { + **provider_params, + **(self.get_model_parameters(provider, provider_params)) + } + self.llm = provider(**unified_parameters) memory = ConversationBufferWindowMemory( memory_key="chat_history", return_messages=True, k=2 ) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 584f0b33f..bcd257632 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -23,8 +23,11 @@ def __init__(self, *args, **kwargs): def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): - model_parameters = self.get_model_parameters(provider, provider_params) - llm = provider(**provider_params, **model_parameters) + unified_parameters = { + **provider_params, + **(self.get_model_parameters(provider, provider_params)) + } + llm = provider(**unified_parameters) prompt_template = llm.get_chat_prompt_template() self.memory = ConversationBufferWindowMemory( diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index 50ef6b02c..d61820445 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -236,8 +236,11 @@ def __init__(self, preferred_dir: str, log_dir: Optional[str], *args, **kwargs): def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): - model_parameters = self.get_model_parameters(provider, provider_params) - llm = provider(**provider_params, **model_parameters) + unified_parameters = { + **provider_params, + **(self.get_model_parameters(provider, provider_params)) + } + llm = provider(**unified_parameters) self.llm = llm return llm diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py index 552d23791..59e4b578a 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py @@ -28,8 +28,11 @@ def __init__(self, *args, **kwargs): def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): - model_parameters = self.get_model_parameters(provider, provider_params) - llm = provider(**provider_params, **model_parameters) + unified_parameters = { + **provider_params, + **(self.get_model_parameters(provider, provider_params)) + } + llm = provider(**unified_parameters) prompt_template = llm.get_completion_prompt_template() From a557d742cc60dcbd999d71addd67868047d68201 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Feb 2024 18:04:25 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py | 2 +- packages/jupyter-ai/jupyter_ai/chat_handlers/default.py | 2 +- packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py | 2 +- packages/jupyter-ai/jupyter_ai/completions/handlers/default.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index 128a31326..c2d657e69 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -45,7 +45,7 @@ def create_llm_chain( ): unified_parameters = { **provider_params, - **(self.get_model_parameters(provider, provider_params)) + **(self.get_model_parameters(provider, provider_params)), } self.llm = provider(**unified_parameters) memory = ConversationBufferWindowMemory( diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index bcd257632..8352a8f8d 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -25,7 +25,7 @@ def create_llm_chain( ): unified_parameters = { **provider_params, - **(self.get_model_parameters(provider, provider_params)) + **(self.get_model_parameters(provider, provider_params)), } llm = provider(**unified_parameters) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index d61820445..539a74159 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -238,7 +238,7 @@ def create_llm_chain( ): unified_parameters = { **provider_params, - **(self.get_model_parameters(provider, provider_params)) + **(self.get_model_parameters(provider, provider_params)), } llm = provider(**unified_parameters) diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py index 59e4b578a..eb03df156 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py @@ -30,7 +30,7 @@ def create_llm_chain( ): unified_parameters = { **provider_params, - **(self.get_model_parameters(provider, provider_params)) + **(self.get_model_parameters(provider, provider_params)), } llm = provider(**unified_parameters)