diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
index c7c72666b..17fa42656 100644
--- a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
+++ b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
@@ -48,10 +48,11 @@ def _stream(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
-        time.sleep(5)
+        time.sleep(1)
         yield GenerationChunk(
-            text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n"
+            text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 5.\n\n",
+            generation_info={"test_metadata_field": "foobar"},
         )
-        for i in range(1, 101):
-            time.sleep(0.5)
+        for i in range(1, 6):
+            time.sleep(0.2)
             yield GenerationChunk(text=f"{i}, ")
diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py
new file mode 100644
index 000000000..4567ecba6
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py
@@ -0,0 +1,6 @@
+"""
+Provides classes which extend `langchain_core.callbacks:BaseCallbackHandler`.
+Not to be confused with Jupyter AI chat handlers.
+"""
+
+from .metadata import MetadataCallbackHandler
diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py
new file mode 100644
index 000000000..145cab313
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py
@@ -0,0 +1,26 @@
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+
+class MetadataCallbackHandler(BaseCallbackHandler):
+    """
+    When passed as a callback handler, this stores the LLMResult's
+    `generation_info` dictionary in the `self.jai_metadata` instance attribute
+    after the provider fully processes an input.
+
+    If used in a streaming chat handler: the `metadata` field of the final
+    `AgentStreamChunkMessage` should be set to `self.jai_metadata`.
+
+    If used in a non-streaming chat handler: the `metadata` field of the
+    returned `AgentChatMessage` should be set to `self.jai_metadata`.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.jai_metadata = {}
+
+    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
+        if not (len(response.generations) and len(response.generations[0])):
+            return
+
+        self.jai_metadata = response.generations[0][0].generation_info or {}
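The handler above only reads `response.generations[0][0].generation_info`, so its behavior can be checked in isolation. A minimal sketch, not part of the diff: the `LLMResult` here is constructed by hand purely to simulate what LangChain passes to `on_llm_end` at the end of a run.

```python
from jupyter_ai.callback_handlers import MetadataCallbackHandler
from langchain_core.outputs import Generation, LLMResult

handler = MetadataCallbackHandler()

# Simulate the end of an LLM run: LangChain calls `on_llm_end` with the final
# LLMResult, whose `generation_info` carries the provider metadata.
handler.on_llm_end(
    LLMResult(
        generations=[
            [Generation(text="42", generation_info={"test_metadata_field": "foobar"})]
        ]
    )
)

print(handler.jai_metadata)  # {'test_metadata_field': 'foobar'}
```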
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.jai_metadata = {} + + def on_llm_end(self, response: LLMResult, **kwargs) -> None: + if not (len(response.generations) and len(response.generations[0])): + return + + self.jai_metadata = response.generations[0][0].generation_info or {} diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index dc6753b58..607dc92fc 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,8 +1,9 @@ import asyncio import time -from typing import Dict, Type +from typing import Any, Dict, Type from uuid import uuid4 +from jupyter_ai.callback_handlers import MetadataCallbackHandler from jupyter_ai.models import ( AgentStreamChunkMessage, AgentStreamMessage, @@ -85,13 +86,19 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str: return stream_id - def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False): + def _send_stream_chunk( + self, + stream_id: str, + content: str, + complete: bool = False, + metadata: Dict[str, Any] = {}, + ): """ Sends an `agent-stream-chunk` message containing content that should be appended to an existing `agent-stream` message with ID `stream_id`. """ stream_chunk_msg = AgentStreamChunkMessage( - id=stream_id, content=content, stream_complete=complete + id=stream_id, content=content, stream_complete=complete, metadata=metadata ) for handler in self._root_chat_handlers.values(): @@ -104,6 +111,7 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = Fals async def process_message(self, message: HumanChatMessage): self.get_llm_chain() received_first_chunk = False + assert self.llm_chain inputs = {"input": message.body} if "context" in self.prompt_template.input_variables: @@ -121,10 +129,13 @@ async def process_message(self, message: HumanChatMessage): # stream response in chunks. this works even if a provider does not # implement streaming, as `astream()` defaults to yielding `_call()` # when `_stream()` is not implemented on the LLM class. - assert self.llm_chain + metadata_handler = MetadataCallbackHandler() async for chunk in self.llm_chain.astream( inputs, - config={"configurable": {"last_human_msg": message}}, + config={ + "configurable": {"last_human_msg": message}, + "callbacks": [metadata_handler], + }, ): if not received_first_chunk: # when receiving the first chunk, close the pending message and @@ -142,7 +153,9 @@ async def process_message(self, message: HumanChatMessage): break # complete stream after all chunks have been streamed - self._send_stream_chunk(stream_id, "", complete=True) + self._send_stream_chunk( + stream_id, "", complete=True, metadata=metadata_handler.jai_metadata + ) async def make_context_prompt(self, human_msg: HumanChatMessage) -> str: return "\n\n".join( diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index e951ac6e8..9bd59ca28 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -87,6 +87,14 @@ class BaseAgentMessage(BaseModel): this defaults to a description of `JupyternautPersona`. """ + metadata: Dict[str, Any] = {} + """ + Message metadata set by a provider after fully processing an input. The + contents of this dictionary are provider-dependent, and can be any + dictionary with string keys. 
diff --git a/packages/jupyter-ai/src/chat_handler.ts b/packages/jupyter-ai/src/chat_handler.ts
index 76c93a851..e1b1e332c 100644
--- a/packages/jupyter-ai/src/chat_handler.ts
+++ b/packages/jupyter-ai/src/chat_handler.ts
@@ -170,6 +170,7 @@ export class ChatHandler implements IDisposable {
         }
 
         streamMessage.body += newMessage.content;
+        streamMessage.metadata = newMessage.metadata;
         if (newMessage.stream_complete) {
           streamMessage.complete = true;
         }
diff --git a/packages/jupyter-ai/src/components/chat-messages.tsx b/packages/jupyter-ai/src/components/chat-messages.tsx
index aaae93a18..961b884b2 100644
--- a/packages/jupyter-ai/src/components/chat-messages.tsx
+++ b/packages/jupyter-ai/src/components/chat-messages.tsx
@@ -74,6 +74,10 @@ function sortMessages(
 export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element {
   const collaborators = useCollaboratorsContext();
 
+  if (props.message.type === 'agent-stream' && props.message.complete) {
+    console.log(props.message.metadata);
+  }
+
   const sharedStyles: SxProps = {
     height: '24px',
     width: '24px'
diff --git a/packages/jupyter-ai/src/components/pending-messages.tsx b/packages/jupyter-ai/src/components/pending-messages.tsx
index 3635e41e0..c258c295e 100644
--- a/packages/jupyter-ai/src/components/pending-messages.tsx
+++ b/packages/jupyter-ai/src/components/pending-messages.tsx
@@ -60,7 +60,8 @@ export function PendingMessages(
       time: lastMessage.time,
       body: '',
       reply_to: '',
-      persona: lastMessage.persona
+      persona: lastMessage.persona,
+      metadata: {}
     });
 
     // timestamp format copied from ChatMessage
diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts
index 43ab45203..b653015f4 100644
--- a/packages/jupyter-ai/src/handler.ts
+++ b/packages/jupyter-ai/src/handler.ts
@@ -114,6 +114,7 @@ export namespace AiService {
     body: string;
     reply_to: string;
     persona: Persona;
+    metadata: Record<string, any>;
   };
 
   export type HumanChatMessage = {
@@ -172,6 +173,7 @@ export namespace AiService {
     id: string;
     content: string;
     stream_complete: boolean;
+    metadata: Record<string, any>;
   };
 
   export type Request = ChatRequest | ClearRequest;
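On the frontend, `chat_handler.ts` simply overwrites `streamMessage.metadata` with each chunk's `metadata`, so the final chunk (the one sent with `stream_complete: true`, which `default.py` populates from the callback handler) determines what ends up on the completed `agent-stream` message. A small sketch of that override semantics using the Pydantic models above, with hypothetical IDs and values:

```python
from jupyter_ai.models import AgentStreamChunkMessage

chunks = [
    AgentStreamChunkMessage(id="abc123", content="1, 2, ", stream_complete=False),
    AgentStreamChunkMessage(
        id="abc123",
        content="",
        stream_complete=True,
        metadata={"test_metadata_field": "foobar"},
    ),
]

stream_metadata: dict = {}
for chunk in chunks:
    # Mirrors `streamMessage.metadata = newMessage.metadata;` in chat_handler.ts:
    # the latest chunk's metadata replaces, not merges with, the previous value.
    stream_metadata = chunk.metadata

print(stream_metadata)  # {'test_metadata_field': 'foobar'}
```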