From e8d178dc86ae17f9c0437e5948871cede2662af0 Mon Sep 17 00:00:00 2001
From: Jason Weill
Date: Wed, 7 Feb 2024 15:56:31 -0800
Subject: [PATCH] Removes model_parameters, which were redundant with
 provider_params, from chat and completion code

---
 packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py            | 3 +--
 packages/jupyter-ai/jupyter_ai/chat_handlers/default.py        | 3 +--
 packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py       | 3 +--
 packages/jupyter-ai/jupyter_ai/completions/handlers/default.py | 3 +--
 4 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
index 8fe5c7f3a..34b9627e6 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -43,8 +43,7 @@ def __init__(self, retriever, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        model_parameters = self.get_model_parameters(provider, provider_params)
-        self.llm = provider(**provider_params, **model_parameters)
+        self.llm = provider(**provider_params)
         memory = ConversationBufferWindowMemory(
             memory_key="chat_history", return_messages=True, k=2
         )
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 584f0b33f..ec454392d 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -23,8 +23,7 @@ def __init__(self, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        model_parameters = self.get_model_parameters(provider, provider_params)
-        llm = provider(**provider_params, **model_parameters)
+        llm = provider(**provider_params)
         prompt_template = llm.get_chat_prompt_template()

         self.memory = ConversationBufferWindowMemory(
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
index 50ef6b02c..920c50b2b 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -236,8 +236,7 @@ def __init__(self, preferred_dir: str, log_dir: Optional[str], *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        model_parameters = self.get_model_parameters(provider, provider_params)
-        llm = provider(**provider_params, **model_parameters)
+        llm = provider(**provider_params)
         self.llm = llm
         return llm

diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
index 552d23791..a04a04954 100644
--- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -28,8 +28,7 @@ def __init__(self, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        model_parameters = self.get_model_parameters(provider, provider_params)
-        llm = provider(**provider_params, **model_parameters)
+        llm = provider(**provider_params)
         prompt_template = llm.get_completion_prompt_template()