Skip to content

Commit

Permalink
Includes participants in conversation gets (#189)
Browse files Browse the repository at this point in the history
And adds parameter to specify what types of messages you care about when
returning the latest message
  • Loading branch information
markwaddle authored Oct 30, 2024
1 parent d46eb4c commit 943e99b
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ class Conversation(BaseModel):
imported_from_conversation_id: uuid.UUID | None
metadata: dict[str, Any]
created_datetime: datetime.datetime

conversation_permission: ConversationPermission
latest_message: ConversationMessage | None
participants: list[ConversationParticipant]


class ConversationList(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""index message_type
Revision ID: 039bec8edc33
Revises: b29524775484
Create Date: 2024-10-30 23:15:36.240812
"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "039bec8edc33"
down_revision: Union[str, None] = "b29524775484"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.create_index(op.f("ix_conversationmessage_message_type"), "conversationmessage", ["message_type"], unique=False)


def downgrade() -> None:
op.drop_index(op.f("ix_conversationmessage_message_type"), table_name="conversationmessage")
124 changes: 114 additions & 10 deletions workbench-service/semantic_workbench_service/controller/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import uuid
from dataclasses import dataclass
from typing import AsyncContextManager, Awaitable, Callable, Literal
from typing import AsyncContextManager, Awaitable, Callable, Iterable, Literal, Sequence

from semantic_workbench_api_model.assistant_service_client import AssistantError
from semantic_workbench_api_model.workbench_model import (
Expand Down Expand Up @@ -83,11 +83,86 @@ async def create_conversation(
await session.commit()
await session.refresh(conversation)

return await self.get_conversation(conversation_id=conversation.conversation_id, principal=user_principal)
return await self.get_conversation(
conversation_id=conversation.conversation_id, principal=user_principal, latest_message_types=set()
)

async def _projections_with_participants(
self,
session: AsyncSession,
conversation_projections: Sequence[tuple[db.Conversation, db.ConversationMessage | None, str]],
) -> Iterable[
tuple[
db.Conversation,
Iterable[db.UserParticipant],
Iterable[db.AssistantParticipant],
dict[uuid.UUID, db.Assistant],
db.ConversationMessage | None,
str,
]
]:
user_participants = (
await session.exec(
select(db.UserParticipant).where(
col(db.UserParticipant.conversation_id).in_([
c[0].conversation_id for c in conversation_projections
])
)
)
).all()

assistant_participants = (
await session.exec(
select(db.AssistantParticipant).where(
col(db.AssistantParticipant.conversation_id).in_([
c[0].conversation_id for c in conversation_projections
])
)
)
).all()

assistants = (
await session.exec(
select(db.Assistant).where(
col(db.Assistant.assistant_id).in_([p.assistant_id for p in assistant_participants])
)
)
).all()
assistants_map = {assistant.assistant_id: assistant for assistant in assistants}

def merge() -> Iterable[
tuple[
db.Conversation,
Iterable[db.UserParticipant],
Iterable[db.AssistantParticipant],
dict[uuid.UUID, db.Assistant],
db.ConversationMessage | None,
str,
]
]:
for conversation, latest_message, permission in conversation_projections:
conversation_id = conversation.conversation_id
conversation_user_participants = (
up for up in user_participants if up.conversation_id == conversation_id
)
conversation_assistant_participants = (
ap for ap in assistant_participants if ap.conversation_id == conversation_id
)
yield (
conversation,
conversation_user_participants,
conversation_assistant_participants,
assistants_map,
latest_message,
permission,
)

return merge()

async def get_conversations(
self,
principal: auth.ActorPrincipal,
latest_message_types: set[MessageType],
include_all_owned: bool = False,
) -> ConversationList:
async with self._get_session() as session:
Expand All @@ -96,17 +171,25 @@ async def get_conversations(
conversation_projections = (
await session.exec(
query.select_conversation_projections_for(
principal=principal, include_all_owned=include_all_owned, include_observer=True
principal=principal,
include_all_owned=include_all_owned,
include_observer=True,
latest_message_types=latest_message_types,
).order_by(col(db.Conversation.created_datetime).desc())
)
).all()

return convert.conversation_list_from_db(models=conversation_projections)
projections_with_participants = await self._projections_with_participants(
session=session, conversation_projections=conversation_projections
)

return convert.conversation_list_from_db(models=projections_with_participants)

async def get_assistant_conversations(
self,
user_principal: auth.UserPrincipal,
assistant_id: uuid.UUID,
latest_message_types: set[MessageType],
) -> ConversationList:
async with self._get_session() as session:
assistant = (
Expand All @@ -119,40 +202,61 @@ async def get_assistant_conversations(
if assistant is None:
raise exceptions.NotFoundError()

conversations = (
conversation_projections = (
await session.exec(
query.select_conversation_projections_for(
principal=auth.AssistantPrincipal(
assistant_service_id=assistant.assistant_service_id, assistant_id=assistant_id
),
latest_message_types=latest_message_types,
)
)
).all()

return convert.conversation_list_from_db(models=conversations)
projections_with_participants = await self._projections_with_participants(
session=session, conversation_projections=conversation_projections
)

return convert.conversation_list_from_db(models=projections_with_participants)

async def get_conversation(
self,
conversation_id: uuid.UUID,
principal: auth.ActorPrincipal,
latest_message_types: set[MessageType],
) -> Conversation:
async with self._get_session() as session:
include_all_owned = isinstance(principal, auth.UserPrincipal)

conversation_projection = (
await session.exec(
query.select_conversation_projections_for(
principal=principal, include_all_owned=include_all_owned, include_observer=True
principal=principal,
include_all_owned=include_all_owned,
include_observer=True,
latest_message_types=latest_message_types,
).where(db.Conversation.conversation_id == conversation_id)
)
).one_or_none()
if conversation_projection is None:
raise exceptions.NotFoundError()

conversation, latest_message, permission = conversation_projection
projections_with_participants = await self._projections_with_participants(
session=session,
conversation_projections=[conversation_projection],
)

conversation, user_participants, assistant_participants, assistants, latest_message, permission = next(
iter(projections_with_participants)
)

return convert.conversation_from_db(
model=conversation, latest_message=latest_message, permission=permission
model=conversation,
latest_message=latest_message,
permission=permission,
user_participants=user_participants,
assistant_participants=assistant_participants,
assistants=assistants,
)

async def update_conversation(
Expand Down Expand Up @@ -190,7 +294,7 @@ async def update_conversation(
await session.refresh(conversation)

conversation_model = await self.get_conversation(
conversation_id=conversation.conversation_id, principal=user_principal
conversation_id=conversation.conversation_id, principal=user_principal, latest_message_types=set()
)

await self._notify_event(
Expand Down
24 changes: 22 additions & 2 deletions workbench-service/semantic_workbench_service/controller/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def conversation_participant_list_from_db(

def conversation_from_db(
model: db.Conversation,
user_participants: Iterable[db.UserParticipant],
assistant_participants: Iterable[db.AssistantParticipant],
assistants: Mapping[uuid.UUID, db.Assistant],
latest_message: db.ConversationMessage | None,
permission: str,
) -> Conversation:
Expand All @@ -159,20 +162,37 @@ def conversation_from_db(
created_datetime=model.created_datetime,
conversation_permission=ConversationPermission(permission),
latest_message=conversation_message_from_db(model=latest_message) if latest_message else None,
participants=conversation_participant_list_from_db(
user_participants=user_participants,
assistant_participants=assistant_participants,
assistants=assistants,
).participants,
)


def conversation_list_from_db(
models: Iterable[tuple[db.Conversation, db.ConversationMessage | None, str]],
models: Iterable[
tuple[
db.Conversation,
Iterable[db.UserParticipant],
Iterable[db.AssistantParticipant],
dict[uuid.UUID, db.Assistant],
db.ConversationMessage | None,
str,
]
],
) -> ConversationList:
return ConversationList(
conversations=[
conversation_from_db(
model=conversation,
user_participants=user_participants,
assistant_participants=assistant_participants,
assistants=assistants,
latest_message=latest_message,
permission=permission,
)
for conversation, latest_message, permission in models
for conversation, user_participants, assistant_participants, assistants, latest_message, permission in models
]
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ async def ensure_configuration_of_conversation_for_workflow_state(
conversation = await self._conversation_controller.get_conversation(
conversation_id=uuid.UUID(conversation_id),
principal=service_user_principals.workflow,
latest_message_types=set(),
)
except Exception as e:
raise exceptions.RuntimeError(
Expand Down Expand Up @@ -1607,6 +1608,7 @@ async def update_conversation_title(
conversation = await self._conversation_controller.get_conversation(
conversation_id=conversation_id,
principal=service_user_principals.workflow,
latest_message_types=set(),
)
except Exception as e:
raise exceptions.RuntimeError(
Expand Down
2 changes: 1 addition & 1 deletion workbench-service/semantic_workbench_service/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class ConversationMessage(SQLModel, table=True):
created_datetime: datetime.datetime = date_time_default_to_now()
sender_participant_id: str
sender_participant_role: str
message_type: str
message_type: str = Field(index=True)
content: str
content_type: str
meta_data: dict[str, Any] = Field(sa_column=sqlalchemy.Column("metadata", sqlalchemy.JSON), default={})
Expand Down
3 changes: 3 additions & 0 deletions workbench-service/semantic_workbench_service/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sqlalchemy import Function
from sqlmodel import String, and_, cast, col, func, literal, or_, select
from sqlmodel.sql.expression import Select, SelectOfScalar
from semantic_workbench_api_model.workbench_model import MessageType

from . import auth, db, settings

Expand Down Expand Up @@ -118,6 +119,7 @@ def select_conversations_for(

def select_conversation_projections_for(
principal: auth.ActorPrincipal,
latest_message_types: set[MessageType],
include_all_owned: bool = False,
include_observer: bool = False,
) -> Select[tuple[db.Conversation, db.ConversationMessage | None, str]]:
Expand Down Expand Up @@ -148,6 +150,7 @@ def select_conversation_projections_for(
db.ConversationMessage.conversation_id,
func.max(db.ConversationMessage.sequence).label("latest_message_sequence"),
)
.where(col(db.ConversationMessage.message_type).in_(latest_message_types))
.group_by(col(db.ConversationMessage.conversation_id))
.subquery()
)
Expand Down
Loading

0 comments on commit 943e99b

Please sign in to comment.