Skip to content

Commit

Permalink
AssistantApp events are filtered to external events (#191)
Browse files Browse the repository at this point in the history
By default (ie. events not generated by this assistant instance)

The decorator accepts an include parameter which can be set to "all" to
receive all events, even those originating from this assistant instance
  • Loading branch information
markwaddle authored Oct 31, 2024
1 parent 0e6f10d commit 4b1f24b
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 59 deletions.
6 changes: 1 addition & 5 deletions assistants/explorer-assistant/assistant/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,6 @@ async def on_message_created(
- @assistant.events.conversation.message.on_created
"""

# ignore messages from this assistant
if message.sender.participant_id == context.assistant.id:
return

# update the participant status to indicate the assistant is thinking
async with context.set_status_for_block("thinking..."):
config = await assistant_config.get(context.assistant)
Expand Down Expand Up @@ -413,7 +409,7 @@ async def respond_to_conversation(

# check if the completion total tokens exceed the warning threshold
if completion_total_tokens > token_count_for_warning:
content = f"{config.high_token_usage_warning.message}\n\n" f"Total tokens used: {completion_total_tokens}"
content = f"{config.high_token_usage_warning.message}\n\nTotal tokens used: {completion_total_tokens}"

# send a notice message to the conversation that the token usage is high
await context.send_messages(
Expand Down
4 changes: 0 additions & 4 deletions assistants/guided-conversation-assistant/assistant/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,6 @@ async def on_message_created(
- @assistant.events.conversation.message.on_created
"""

# ignore messages from this assistant
if message.sender.participant_id == context.assistant.id:
return

# update the participant status to indicate the assistant is thinking
await context.update_participant_me(UpdateParticipant(status="thinking..."))
try:
Expand Down
4 changes: 0 additions & 4 deletions assistants/prospector-assistant/assistant/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,6 @@ async def on_chat_message_created(
- @assistant.events.conversation.message.on_created
"""

# ignore messages from this assistant
if message.sender.participant_id == context.assistant.id:
return

# update the participant status to indicate the assistant is thinking
async with context.set_status_for_block("thinking..."):
config = await assistant_config.get(context.assistant)
Expand Down
8 changes: 0 additions & 8 deletions assistants/skill-assistant/assistant/skill_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ async def on_message_created(
- @assistant.events.conversation.message.on_created
"""

# ignore messages from this assistant
if message.sender.participant_id == context.assistant.id:
return

# pass the message to the core response logic
await respond_to_conversation(context, event, message)

Expand All @@ -118,10 +114,6 @@ async def on_command_message_created(
Handle the event triggered when a new command message is created in the conversation.
"""

# ignore messages from this assistant
if message.sender.participant_id == context.assistant.id:
return

# pass the message to the core response logic
await respond_to_conversation(context, event, message)

Expand Down
3 changes: 0 additions & 3 deletions examples/python/python-02-simple-chatbot/assistant/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,6 @@ async def on_message_created(
- To handle all message types, you can use the root event handler for all message types:
- @assistant.events.conversation.message.on_created
"""
# ignore messages from this assistant
if message.sender.participant_id == context.assistant.id:
return

# update the participant status to indicate the assistant is thinking
await context.update_participant_me(UpdateParticipant(status="thinking..."))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ async def on_message_created(
- To handle all message types, you can use the root event handler for all message types:
- @assistant.events.conversation.message.on_created
"""
# ignore messages from this assistant
if message.sender.participant_id == context.assistant.id:
return

# update the participant status to indicate the assistant is thinking
await context.update_participant_me(UpdateParticipant(status="thinking..."))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
Awaitable,
Callable,
Generic,
Literal,
Mapping,
Protocol,
TypeVar,
Union,
overload,
)

import typing_extensions
Expand Down Expand Up @@ -93,15 +95,25 @@ async def set(self, assistant_context: AssistantContext, config: dict[str, Any])
EventHandlerT = TypeVar("EventHandlerT")


class EventHandlerList(Generic[EventHandlerT], list[EventHandlerT]):
async def __call__(self, *args, **kwargs):
for handler in self:
IncludeEventsFromActors = Literal["all", "others", "this_assistant_service"]


class EventHandlerList(Generic[EventHandlerT], list[tuple[EventHandlerT, IncludeEventsFromActors]]):
async def __call__(self, external_event: bool, *args, **kwargs):
for handler, include in self:
if external_event and include == "this_assistant_service":
continue
if not external_event and include == "others":
continue

try:
if asyncio.iscoroutinefunction(handler):
return await handler(*args, **kwargs)
await handler(*args, **kwargs)
continue

if callable(handler):
return handler(*args, **kwargs)
handler(*args, **kwargs)
continue

except Exception:
logger.exception("error in event handler {handler}")
Expand Down Expand Up @@ -137,10 +149,37 @@ def __init__(self) -> None:
self.on_service_shutdown = _create_decorator(self._on_service_shutdown_handlers)


def _create_decorator(handler_list: list[EventHandlerT]) -> Callable[[EventHandlerT], EventHandlerT]:
def decorator(func: EventHandlerT) -> EventHandlerT:
handler_list.append(func)
return func
def _create_decorator(
handler_list: EventHandlerList[EventHandlerT],
):
@overload
def decorator(func_or_include: EventHandlerT) -> EventHandlerT: ...

@overload
def decorator(
func_or_include: IncludeEventsFromActors | None = "others",
) -> Callable[[EventHandlerT], EventHandlerT]: ...

def decorator(
func_or_include: EventHandlerT | IncludeEventsFromActors | None = "others",
) -> EventHandlerT | Callable[[EventHandlerT], EventHandlerT]:
filter: IncludeEventsFromActors = "others"
match func_or_include:
case "all":
filter = "all"
case "this_assistant_service":
filter = "this_assistant_service"

def _decorator(func: EventHandlerT) -> EventHandlerT:
handler_list.append((func, filter))
return func

# decorator with no arguments
if callable(func_or_include):
return _decorator(func_or_include)

# decorator with arguments
return _decorator

return decorator

Expand Down Expand Up @@ -185,6 +224,23 @@ def __init__(self) -> None:
for event_type in workbench_model.MessageType:
assert getattr(self, str(event_type).replace("-", "_"))

def __getitem__(self, key: workbench_model.MessageType) -> ObjectEventHandlers[ConversationMessageEventHandler]:
match key:
case workbench_model.MessageType.chat:
return self.chat
case workbench_model.MessageType.log:
return self.log
case workbench_model.MessageType.note:
return self.note
case workbench_model.MessageType.notice:
return self.notice
case workbench_model.MessageType.command:
return self.command
case workbench_model.MessageType.command_response:
return self.command_response
case _:
raise KeyError(key)


class ConversationEvents(ObjectEventHandlers[ConversationEventHandler]):
def __init__(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ def __init__(

@asynccontextmanager
async def lifespan(self) -> AsyncIterator[None]:
await self.assistant_app.events._on_service_start_handlers()
await self.assistant_app.events._on_service_start_handlers(True)

try:
yield
finally:
await self.assistant_app.events._on_service_shutdown_handlers()
await self.assistant_app.events._on_service_shutdown_handlers(True)

for task in self._conversation_event_tasks:
task.cancel()
Expand Down Expand Up @@ -259,9 +259,9 @@ async def put_assistant(

assistant_context = require_found(self.get_assistant_context(assistant_id))
if is_new:
await self.assistant_app.events.assistant._on_created_handlers(assistant_context)
await self.assistant_app.events.assistant._on_created_handlers(True, assistant_context)
else:
await self.assistant_app.events.assistant._on_updated_handlers(assistant_context)
await self.assistant_app.events.assistant._on_updated_handlers(True, assistant_context)

if from_export is not None:
await self.assistant_app.data_exporter.import_(assistant_context, from_export)
Expand Down Expand Up @@ -304,7 +304,7 @@ async def delete_assistant(self, assistant_id: str) -> None:
states.assistants.pop(assistant_id, None)
self.write_assistant_states(states)

await self.assistant_app.events.assistant._on_deleted_handlers(assistant_context)
await self.assistant_app.events.assistant._on_deleted_handlers(True, assistant_context)

@translate_assistant_errors
async def get_config(self, assistant_id: str) -> assistant_model.ConfigResponseModel:
Expand Down Expand Up @@ -352,9 +352,9 @@ async def put_conversation(
conversation_context = require_found(self.get_conversation_context(assistant_id, conversation_id))

if is_new:
await self.assistant_app.events.conversation._on_created_handlers(conversation_context)
await self.assistant_app.events.conversation._on_created_handlers(True, conversation_context)
else:
await self.assistant_app.events.conversation._on_updated_handlers(conversation_context)
await self.assistant_app.events.conversation._on_updated_handlers(True, conversation_context)

if from_export is not None:
await self.assistant_app.conversation_data_exporter.import_(conversation_context, from_export)
Expand Down Expand Up @@ -391,7 +391,7 @@ async def delete_conversation(self, assistant_id: str, conversation_id: str) ->
return
self.write_assistant_states(states)

await self.assistant_app.events.conversation._on_deleted_handlers(conversation_context)
await self.assistant_app.events.conversation._on_deleted_handlers(True, conversation_context)

async def _get_or_create_queue(self, assistant_id: str, conversation_id: str) -> asyncio.Queue[_Event]:
key = (assistant_id, conversation_id)
Expand Down Expand Up @@ -487,17 +487,18 @@ async def _forward_event(
logging.exception("invalid message event data")
return

type_specific_events = getattr(
self.assistant_app.events.conversation.message, str(message.message_type).replace("-", "_")
)
event_originated_externally = message.sender.participant_id != conversation_context.assistant.id

async with asyncio.TaskGroup() as tg:
tg.create_task(
self.assistant_app.events.conversation.message._on_created_handlers(
conversation_context, updated_event, message
event_originated_externally, conversation_context, updated_event, message
)
)
tg.create_task(
type_specific_events._on_created_handlers(conversation_context, updated_event, message)
self.assistant_app.events.conversation.message[message.message_type]._on_created_handlers(
event_originated_externally, conversation_context, updated_event, message
)
)

case workbench_model.ConversationEventType.message_deleted:
Expand All @@ -507,17 +508,18 @@ async def _forward_event(
logging.exception("invalid message event data")
return

type_specific_events = getattr(
self.assistant_app.events.conversation.message, str(message.message_type).replace("-", "_")
)
event_originated_externally = message.sender.participant_id != conversation_context.assistant.id

async with asyncio.TaskGroup() as tg:
tg.create_task(
self.assistant_app.events.conversation.message._on_deleted_handlers(
conversation_context, updated_event, message
event_originated_externally, conversation_context, updated_event, message
)
)
tg.create_task(
type_specific_events._on_deleted_handlers(conversation_context, updated_event, message)
self.assistant_app.events.conversation.message[message.message_type]._on_deleted_handlers(
event_originated_externally, conversation_context, updated_event, message
)
)

case workbench_model.ConversationEventType.participant_created:
Expand All @@ -529,8 +531,9 @@ async def _forward_event(
logging.exception("invalid participant event data")
return

event_originated_externally = participant.id != conversation_context.assistant.id
await self.assistant_app.events.conversation.participant._on_created_handlers(
conversation_context, updated_event, participant
event_originated_externally, conversation_context, updated_event, participant
)

case workbench_model.ConversationEventType.participant_updated:
Expand All @@ -542,8 +545,9 @@ async def _forward_event(
logging.exception("invalid participant event data")
return

event_originated_externally = participant.id != conversation_context.assistant.id
await self.assistant_app.events.conversation.participant._on_updated_handlers(
conversation_context, updated_event, participant
event_originated_externally, conversation_context, updated_event, participant
)

case workbench_model.ConversationEventType.file_created:
Expand All @@ -553,8 +557,9 @@ async def _forward_event(
logging.exception("invalid file event data")
return

event_originated_externally = file.participant_id != conversation_context.assistant.id
await self.assistant_app.events.conversation.file._on_created_handlers(
conversation_context, updated_event, file
event_originated_externally, conversation_context, updated_event, file
)

case workbench_model.ConversationEventType.file_updated:
Expand All @@ -564,8 +569,9 @@ async def _forward_event(
logging.exception("invalid file event data")
return

event_originated_externally = file.participant_id != conversation_context.assistant.id
await self.assistant_app.events.conversation.file._on_updated_handlers(
conversation_context, updated_event, file
event_originated_externally, conversation_context, updated_event, file
)

case workbench_model.ConversationEventType.file_deleted:
Expand All @@ -575,8 +581,9 @@ async def _forward_event(
logging.exception("invalid file event data")
return

event_originated_externally = file.participant_id != conversation_context.assistant.id
await self.assistant_app.events.conversation.file._on_deleted_handlers(
conversation_context, updated_event, file
event_originated_externally, conversation_context, updated_event, file
)

@translate_assistant_errors
Expand Down
Loading

0 comments on commit 4b1f24b

Please sign in to comment.