diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
index 8fe5c7f3a..c2d657e69 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -43,8 +43,11 @@ def __init__(self, retriever, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        model_parameters = self.get_model_parameters(provider, provider_params)
-        self.llm = provider(**provider_params, **model_parameters)
+        unified_parameters = {
+            **provider_params,
+            **(self.get_model_parameters(provider, provider_params)),
+        }
+        self.llm = provider(**unified_parameters)
         memory = ConversationBufferWindowMemory(
             memory_key="chat_history", return_messages=True, k=2
         )
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 584f0b33f..8352a8f8d 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -23,8 +23,11 @@ def __init__(self, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        model_parameters = self.get_model_parameters(provider, provider_params)
-        llm = provider(**provider_params, **model_parameters)
+        unified_parameters = {
+            **provider_params,
+            **(self.get_model_parameters(provider, provider_params)),
+        }
+        llm = provider(**unified_parameters)
 
         prompt_template = llm.get_chat_prompt_template()
         self.memory = ConversationBufferWindowMemory(
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
index 50ef6b02c..539a74159 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -236,8 +236,11 @@ def __init__(self, preferred_dir: str, log_dir: Optional[str], *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        model_parameters = self.get_model_parameters(provider, provider_params)
-        llm = provider(**provider_params, **model_parameters)
+        unified_parameters = {
+            **provider_params,
+            **(self.get_model_parameters(provider, provider_params)),
+        }
+        llm = provider(**unified_parameters)
 
         self.llm = llm
         return llm
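
Note (not part of the patch): a minimal sketch of why merging the two dicts into unified_parameters changes behavior. The parameter names and values below are hypothetical, and provider here is a plain stand-in callable rather than the real BaseProvider API. Unpacking two dicts directly into a call raises a TypeError when a key appears in both; merging them into one dict first lets the later dict (the model parameters) quietly override the earlier one (provider_params).

# Hypothetical inputs: "region_name" appears in both dicts.
provider_params = {"model_id": "some-model", "region_name": "us-east-1"}
model_parameters = {"region_name": "us-west-2", "temperature": 0.5}


def provider(**kwargs):
    """Stand-in for a provider class; just echoes the keyword arguments it receives."""
    return kwargs


# Old pattern: duplicate keyword across the two unpackings raises
# "TypeError: provider() got multiple values for keyword argument 'region_name'".
try:
    provider(**provider_params, **model_parameters)
except TypeError as exc:
    print(exc)

# New pattern: later keys win in a dict literal, so the model parameters
# override the defaults from provider_params without an error.
unified_parameters = {**provider_params, **model_parameters}
print(provider(**unified_parameters))
# {'model_id': 'some-model', 'region_name': 'us-west-2', 'temperature': 0.5}

The same precedence applies in all three handlers above: values returned by get_model_parameters take priority over the entries already present in provider_params.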