Skip to content

Commit

Permalink
Unifies parameters to instantiate llm while incorporating model params (
Browse files Browse the repository at this point in the history
jupyterlab#632)

* Unifies parameters to instantiate llm while incorporating model params

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
JasonWeill and pre-commit-ci[bot] authored Feb 8, 2024
1 parent a6868ff commit 17326b5
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 8 deletions.
7 changes: 5 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ def __init__(self, retriever, *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
model_parameters = self.get_model_parameters(provider, provider_params)
self.llm = provider(**provider_params, **model_parameters)
unified_parameters = {
**provider_params,
**(self.get_model_parameters(provider, provider_params)),
}
self.llm = provider(**unified_parameters)
memory = ConversationBufferWindowMemory(
memory_key="chat_history", return_messages=True, k=2
)
Expand Down
7 changes: 5 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ def __init__(self, *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)
unified_parameters = {
**provider_params,
**(self.get_model_parameters(provider, provider_params)),
}
llm = provider(**unified_parameters)

prompt_template = llm.get_chat_prompt_template()
self.memory = ConversationBufferWindowMemory(
Expand Down
7 changes: 5 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,11 @@ def __init__(self, preferred_dir: str, log_dir: Optional[str], *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)
unified_parameters = {
**provider_params,
**(self.get_model_parameters(provider, provider_params)),
}
llm = provider(**unified_parameters)

self.llm = llm
return llm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ def __init__(self, *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)
unified_parameters = {
**provider_params,
**(self.get_model_parameters(provider, provider_params)),
}
llm = provider(**unified_parameters)

prompt_template = llm.get_completion_prompt_template()

Expand Down

0 comments on commit 17326b5

Please sign in to comment.