Skip to content

Commit

Permalink
reverted fix.py
Browse files Browse the repository at this point in the history
  • Loading branch information
srdas committed Oct 23, 2024
1 parent 7c77d2b commit a1b388c
Showing 1 changed file with 19 additions and 55 deletions.
74 changes: 19 additions & 55 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import asyncio
from typing import Dict, Type

from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables.history import RunnableWithMessageHistory

from ..context_providers import ContextProviderException, find_commands
from .base import BaseChatHandler, SlashCommandRoutingType

FIX_STRING_TEMPLATE = """
Expand Down Expand Up @@ -67,36 +64,22 @@ class FixChatHandler(BaseChatHandler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt_template = None

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
unified_parameters = {
"verbose": True,
**provider_params,
**(self.get_model_parameters(provider, provider_params)),
}
llm = provider(**unified_parameters)

self.llm = llm
prompt_template = llm.get_chat_prompt_template()
self.prompt_template = prompt_template

runnable = prompt_template | llm # type:ignore
if not llm.manages_history:
runnable = RunnableWithMessageHistory(
runnable=runnable, # type:ignore[arg-type]
get_session_history=self.get_llm_chat_memory,
input_messages_key="input",
history_messages_key="history",
history_factory_config=[
ConfigurableFieldSpec(
id="last_human_msg",
annotation=HumanChatMessage,
),
],
)
self.llm_chain = runnable
# TODO: migrate this class to use a LCEL `Runnable` instead of
# `Chain`, then remove the below ignore comment.
self.llm_chain = LLMChain( # type:ignore[arg-type]
llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True
)

async def process_message(self, message: HumanChatMessage):
if not (message.selection and message.selection.type == "cell-with-error"):
Expand All @@ -113,35 +96,16 @@ async def process_message(self, message: HumanChatMessage):
extra_instructions = message.prompt[4:].strip() or "None."

self.get_llm_chain()
assert self.llm_chain

inputs = {"input": message.body}
if "context" in self.prompt_template.input_variables:
# include context from context providers.
try:
context_prompt = await self.make_context_prompt(message)
except ContextProviderException as e:
self.reply(str(e), message)
return
inputs["context"] = context_prompt
inputs["input"] = self.replace_prompt(inputs["input"])

await self.stream_reply(inputs, message)

async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
return "\n\n".join(
await asyncio.gather(
*[
provider.make_context_prompt(human_msg)
for provider in self.context_providers.values()
if find_commands(provider, human_msg.prompt)
]
with self.pending("Analyzing error", message):
assert self.llm_chain
# TODO: migrate this class to use a LCEL `Runnable` instead of
# `Chain`, then remove the below ignore comment.
response = await self.llm_chain.apredict( # type:ignore[attr-defined]
extra_instructions=extra_instructions,
stop=["\nHuman:"],
cell_content=selection.source,
error_name=selection.error.name,
error_value=selection.error.value,
traceback="\n".join(selection.error.traceback),
)
)

def replace_prompt(self, prompt: str) -> str:
# modifies prompt by the context providers.
# some providers may modify or remove their '@' commands from the prompt.
for provider in self.context_providers.values():
prompt = provider.replace_prompt(prompt)
return prompt
self.reply(response, message)

0 comments on commit a1b388c

Please sign in to comment.