From 943e99bea68f9ea4d40106fb613d1079536a1864 Mon Sep 17 00:00:00 2001 From: Mark Waddle Date: Wed, 30 Oct 2024 16:19:07 -0700 Subject: [PATCH] Includes participants in conversation gets (#189) And adds parameter to specify what types of messages you care about when returning the latest message --- .../workbench_model.py | 2 + ..._231536_039bec8edc33_index_message_type.py | 25 ++++ .../controller/conversation.py | 124 ++++++++++++++++-- .../controller/convert.py | 24 +++- .../controller/workflow.py | 2 + .../semantic_workbench_service/db.py | 2 +- .../semantic_workbench_service/query.py | 3 + .../semantic_workbench_service/service.py | 18 ++- .../tests/test_workbench_service.py | 9 +- 9 files changed, 193 insertions(+), 16 deletions(-) create mode 100644 workbench-service/migrations/versions/2024_10_30_231536_039bec8edc33_index_message_type.py diff --git a/libraries/python/semantic-workbench-api-model/semantic_workbench_api_model/workbench_model.py b/libraries/python/semantic-workbench-api-model/semantic_workbench_api_model/workbench_model.py index 8188b971..9323a96e 100644 --- a/libraries/python/semantic-workbench-api-model/semantic_workbench_api_model/workbench_model.py +++ b/libraries/python/semantic-workbench-api-model/semantic_workbench_api_model/workbench_model.py @@ -97,8 +97,10 @@ class Conversation(BaseModel): imported_from_conversation_id: uuid.UUID | None metadata: dict[str, Any] created_datetime: datetime.datetime + conversation_permission: ConversationPermission latest_message: ConversationMessage | None + participants: list[ConversationParticipant] class ConversationList(BaseModel): diff --git a/workbench-service/migrations/versions/2024_10_30_231536_039bec8edc33_index_message_type.py b/workbench-service/migrations/versions/2024_10_30_231536_039bec8edc33_index_message_type.py new file mode 100644 index 00000000..8a5685ac --- /dev/null +++ b/workbench-service/migrations/versions/2024_10_30_231536_039bec8edc33_index_message_type.py @@ -0,0 +1,25 @@ +"""index message_type + +Revision ID: 039bec8edc33 +Revises: b29524775484 +Create Date: 2024-10-30 23:15:36.240812 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "039bec8edc33" +down_revision: Union[str, None] = "b29524775484" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_index(op.f("ix_conversationmessage_message_type"), "conversationmessage", ["message_type"], unique=False) + + +def downgrade() -> None: + op.drop_index(op.f("ix_conversationmessage_message_type"), table_name="conversationmessage") diff --git a/workbench-service/semantic_workbench_service/controller/conversation.py b/workbench-service/semantic_workbench_service/controller/conversation.py index 7e266898..e81bde11 100644 --- a/workbench-service/semantic_workbench_service/controller/conversation.py +++ b/workbench-service/semantic_workbench_service/controller/conversation.py @@ -2,7 +2,7 @@ import logging import uuid from dataclasses import dataclass -from typing import AsyncContextManager, Awaitable, Callable, Literal +from typing import AsyncContextManager, Awaitable, Callable, Iterable, Literal, Sequence from semantic_workbench_api_model.assistant_service_client import AssistantError from semantic_workbench_api_model.workbench_model import ( @@ -83,11 +83,86 @@ async def create_conversation( await session.commit() await session.refresh(conversation) - return await self.get_conversation(conversation_id=conversation.conversation_id, principal=user_principal) + return await self.get_conversation( + conversation_id=conversation.conversation_id, principal=user_principal, latest_message_types=set() + ) + + async def _projections_with_participants( + self, + session: AsyncSession, + conversation_projections: Sequence[tuple[db.Conversation, db.ConversationMessage | None, str]], + ) -> Iterable[ + tuple[ + db.Conversation, + Iterable[db.UserParticipant], + Iterable[db.AssistantParticipant], + dict[uuid.UUID, db.Assistant], + db.ConversationMessage | None, + str, + ] + ]: + user_participants = ( + await session.exec( + select(db.UserParticipant).where( + col(db.UserParticipant.conversation_id).in_([ + c[0].conversation_id for c in conversation_projections + ]) + ) + ) + ).all() + + assistant_participants = ( + await session.exec( + select(db.AssistantParticipant).where( + col(db.AssistantParticipant.conversation_id).in_([ + c[0].conversation_id for c in conversation_projections + ]) + ) + ) + ).all() + + assistants = ( + await session.exec( + select(db.Assistant).where( + col(db.Assistant.assistant_id).in_([p.assistant_id for p in assistant_participants]) + ) + ) + ).all() + assistants_map = {assistant.assistant_id: assistant for assistant in assistants} + + def merge() -> Iterable[ + tuple[ + db.Conversation, + Iterable[db.UserParticipant], + Iterable[db.AssistantParticipant], + dict[uuid.UUID, db.Assistant], + db.ConversationMessage | None, + str, + ] + ]: + for conversation, latest_message, permission in conversation_projections: + conversation_id = conversation.conversation_id + conversation_user_participants = ( + up for up in user_participants if up.conversation_id == conversation_id + ) + conversation_assistant_participants = ( + ap for ap in assistant_participants if ap.conversation_id == conversation_id + ) + yield ( + conversation, + conversation_user_participants, + conversation_assistant_participants, + assistants_map, + latest_message, + permission, + ) + + return merge() async def get_conversations( self, principal: auth.ActorPrincipal, + latest_message_types: set[MessageType], include_all_owned: bool = False, ) -> ConversationList: async with self._get_session() as session: @@ -96,17 +171,25 @@ async def get_conversations( conversation_projections = ( await session.exec( query.select_conversation_projections_for( - principal=principal, include_all_owned=include_all_owned, include_observer=True + principal=principal, + include_all_owned=include_all_owned, + include_observer=True, + latest_message_types=latest_message_types, ).order_by(col(db.Conversation.created_datetime).desc()) ) ).all() - return convert.conversation_list_from_db(models=conversation_projections) + projections_with_participants = await self._projections_with_participants( + session=session, conversation_projections=conversation_projections + ) + + return convert.conversation_list_from_db(models=projections_with_participants) async def get_assistant_conversations( self, user_principal: auth.UserPrincipal, assistant_id: uuid.UUID, + latest_message_types: set[MessageType], ) -> ConversationList: async with self._get_session() as session: assistant = ( @@ -119,22 +202,28 @@ async def get_assistant_conversations( if assistant is None: raise exceptions.NotFoundError() - conversations = ( + conversation_projections = ( await session.exec( query.select_conversation_projections_for( principal=auth.AssistantPrincipal( assistant_service_id=assistant.assistant_service_id, assistant_id=assistant_id ), + latest_message_types=latest_message_types, ) ) ).all() - return convert.conversation_list_from_db(models=conversations) + projections_with_participants = await self._projections_with_participants( + session=session, conversation_projections=conversation_projections + ) + + return convert.conversation_list_from_db(models=projections_with_participants) async def get_conversation( self, conversation_id: uuid.UUID, principal: auth.ActorPrincipal, + latest_message_types: set[MessageType], ) -> Conversation: async with self._get_session() as session: include_all_owned = isinstance(principal, auth.UserPrincipal) @@ -142,17 +231,32 @@ async def get_conversation( conversation_projection = ( await session.exec( query.select_conversation_projections_for( - principal=principal, include_all_owned=include_all_owned, include_observer=True + principal=principal, + include_all_owned=include_all_owned, + include_observer=True, + latest_message_types=latest_message_types, ).where(db.Conversation.conversation_id == conversation_id) ) ).one_or_none() if conversation_projection is None: raise exceptions.NotFoundError() - conversation, latest_message, permission = conversation_projection + projections_with_participants = await self._projections_with_participants( + session=session, + conversation_projections=[conversation_projection], + ) + + conversation, user_participants, assistant_participants, assistants, latest_message, permission = next( + iter(projections_with_participants) + ) return convert.conversation_from_db( - model=conversation, latest_message=latest_message, permission=permission + model=conversation, + latest_message=latest_message, + permission=permission, + user_participants=user_participants, + assistant_participants=assistant_participants, + assistants=assistants, ) async def update_conversation( @@ -190,7 +294,7 @@ async def update_conversation( await session.refresh(conversation) conversation_model = await self.get_conversation( - conversation_id=conversation.conversation_id, principal=user_principal + conversation_id=conversation.conversation_id, principal=user_principal, latest_message_types=set() ) await self._notify_event( diff --git a/workbench-service/semantic_workbench_service/controller/convert.py b/workbench-service/semantic_workbench_service/controller/convert.py index 6fd535bc..b717a3ba 100644 --- a/workbench-service/semantic_workbench_service/controller/convert.py +++ b/workbench-service/semantic_workbench_service/controller/convert.py @@ -147,6 +147,9 @@ def conversation_participant_list_from_db( def conversation_from_db( model: db.Conversation, + user_participants: Iterable[db.UserParticipant], + assistant_participants: Iterable[db.AssistantParticipant], + assistants: Mapping[uuid.UUID, db.Assistant], latest_message: db.ConversationMessage | None, permission: str, ) -> Conversation: @@ -159,20 +162,37 @@ def conversation_from_db( created_datetime=model.created_datetime, conversation_permission=ConversationPermission(permission), latest_message=conversation_message_from_db(model=latest_message) if latest_message else None, + participants=conversation_participant_list_from_db( + user_participants=user_participants, + assistant_participants=assistant_participants, + assistants=assistants, + ).participants, ) def conversation_list_from_db( - models: Iterable[tuple[db.Conversation, db.ConversationMessage | None, str]], + models: Iterable[ + tuple[ + db.Conversation, + Iterable[db.UserParticipant], + Iterable[db.AssistantParticipant], + dict[uuid.UUID, db.Assistant], + db.ConversationMessage | None, + str, + ] + ], ) -> ConversationList: return ConversationList( conversations=[ conversation_from_db( model=conversation, + user_participants=user_participants, + assistant_participants=assistant_participants, + assistants=assistants, latest_message=latest_message, permission=permission, ) - for conversation, latest_message, permission in models + for conversation, user_participants, assistant_participants, assistants, latest_message, permission in models ] ) diff --git a/workbench-service/semantic_workbench_service/controller/workflow.py b/workbench-service/semantic_workbench_service/controller/workflow.py index 6e5f897b..ef29c208 100644 --- a/workbench-service/semantic_workbench_service/controller/workflow.py +++ b/workbench-service/semantic_workbench_service/controller/workflow.py @@ -661,6 +661,7 @@ async def ensure_configuration_of_conversation_for_workflow_state( conversation = await self._conversation_controller.get_conversation( conversation_id=uuid.UUID(conversation_id), principal=service_user_principals.workflow, + latest_message_types=set(), ) except Exception as e: raise exceptions.RuntimeError( @@ -1607,6 +1608,7 @@ async def update_conversation_title( conversation = await self._conversation_controller.get_conversation( conversation_id=conversation_id, principal=service_user_principals.workflow, + latest_message_types=set(), ) except Exception as e: raise exceptions.RuntimeError( diff --git a/workbench-service/semantic_workbench_service/db.py b/workbench-service/semantic_workbench_service/db.py index bba3d89b..85b911a6 100644 --- a/workbench-service/semantic_workbench_service/db.py +++ b/workbench-service/semantic_workbench_service/db.py @@ -294,7 +294,7 @@ class ConversationMessage(SQLModel, table=True): created_datetime: datetime.datetime = date_time_default_to_now() sender_participant_id: str sender_participant_role: str - message_type: str + message_type: str = Field(index=True) content: str content_type: str meta_data: dict[str, Any] = Field(sa_column=sqlalchemy.Column("metadata", sqlalchemy.JSON), default={}) diff --git a/workbench-service/semantic_workbench_service/query.py b/workbench-service/semantic_workbench_service/query.py index d15958d2..c15a4bb3 100644 --- a/workbench-service/semantic_workbench_service/query.py +++ b/workbench-service/semantic_workbench_service/query.py @@ -3,6 +3,7 @@ from sqlalchemy import Function from sqlmodel import String, and_, cast, col, func, literal, or_, select from sqlmodel.sql.expression import Select, SelectOfScalar +from semantic_workbench_api_model.workbench_model import MessageType from . import auth, db, settings @@ -118,6 +119,7 @@ def select_conversations_for( def select_conversation_projections_for( principal: auth.ActorPrincipal, + latest_message_types: set[MessageType], include_all_owned: bool = False, include_observer: bool = False, ) -> Select[tuple[db.Conversation, db.ConversationMessage | None, str]]: @@ -148,6 +150,7 @@ def select_conversation_projections_for( db.ConversationMessage.conversation_id, func.max(db.ConversationMessage.sequence).label("latest_message_sequence"), ) + .where(col(db.ConversationMessage.message_type).in_(latest_message_types)) .group_by(col(db.ConversationMessage.conversation_id)) .subquery() ) diff --git a/workbench-service/semantic_workbench_service/service.py b/workbench-service/semantic_workbench_service/service.py index e9c01b67..8cf8158c 100644 --- a/workbench-service/semantic_workbench_service/service.py +++ b/workbench-service/semantic_workbench_service/service.py @@ -175,7 +175,13 @@ async def _notify_event(queue_item: ConversationEventQueueItem) -> None: queue_item.event.id, ) - if queue_item.event.event == ConversationEventType.message_created: + if queue_item.event.event in [ + ConversationEventType.message_created, + ConversationEventType.message_deleted, + ConversationEventType.conversation_updated, + ConversationEventType.participant_created, + ConversationEventType.participant_updated, + ]: task = asyncio.create_task( _notify_user_event(queue_item.event.conversation_id), name="notify_user_event" ) @@ -604,10 +610,12 @@ async def delete_assistant( async def get_assistant_conversations( assistant_id: uuid.UUID, user_principal: auth.DependsUserPrincipal, + latest_message_types: Annotated[list[MessageType], Query(alias="latest_message_type")] = [MessageType.chat], ) -> ConversationList: return await conversation_controller.get_assistant_conversations( user_principal=user_principal, assistant_id=assistant_id, + latest_message_types=set(latest_message_types), ) @app.get("/conversations/{conversation_id}/events") @@ -615,7 +623,9 @@ async def conversation_server_sent_events( conversation_id: uuid.UUID, request: Request, user_principal: auth.DependsUserPrincipal ) -> EventSourceResponse: # ensure the conversation exists - await conversation_controller.get_conversation(conversation_id=conversation_id, principal=user_principal) + await conversation_controller.get_conversation( + conversation_id=conversation_id, principal=user_principal, latest_message_types=set() + ) logger.debug( "client connected to sse; user_id: %s, conversation_id: %s", user_principal.user_id, conversation_id @@ -753,20 +763,24 @@ async def create_conversation( async def list_conversations( principal: auth.DependsActorPrincipal, include_inactive: bool = False, + latest_message_types: Annotated[list[MessageType], Query(alias="latest_message_type")] = [MessageType.chat], ) -> ConversationList: return await conversation_controller.get_conversations( principal=principal, include_all_owned=include_inactive, + latest_message_types=set(latest_message_types), ) @app.get("/conversations/{conversation_id}") async def get_conversation( conversation_id: uuid.UUID, principal: auth.DependsActorPrincipal, + latest_message_types: Annotated[list[MessageType], Query(alias="latest_message_type")] = [MessageType.chat], ) -> Conversation: return await conversation_controller.get_conversation( principal=principal, conversation_id=conversation_id, + latest_message_types=set(latest_message_types), ) @app.patch("/conversations/{conversation_id}") diff --git a/workbench-service/tests/test_workbench_service.py b/workbench-service/tests/test_workbench_service.py index 0dd74e54..70c33bb9 100644 --- a/workbench-service/tests/test_workbench_service.py +++ b/workbench-service/tests/test_workbench_service.py @@ -734,11 +734,18 @@ def test_create_conversation_send_user_message(workbench_service: FastAPI, test_ messages = workbench_model.ConversationMessageList.model_validate(http_response.json()) assert len(messages.messages) == 3 - # check latest message in conversation + # check latest chat message in conversation (chat is default) http_response = client.get(f"/conversations/{conversation_id}") assert httpx.codes.is_success(http_response.status_code) conversation = workbench_model.Conversation.model_validate(http_response.json()) assert conversation.latest_message is not None + assert conversation.latest_message.id == message_two_id + + # check latest log message in conversation + http_response = client.get(f"/conversations/{conversation_id}", params={"latest_message_type": ["log"]}) + assert httpx.codes.is_success(http_response.status_code) + conversation = workbench_model.Conversation.model_validate(http_response.json()) + assert conversation.latest_message is not None assert conversation.latest_message.id == message_log_id