diff --git a/src/conversation-api/main.py b/src/conversation-api/main.py index bad26726..33528e63 100644 --- a/src/conversation-api/main.py +++ b/src/conversation-api/main.py @@ -458,7 +458,7 @@ async def loop_func() -> bool: ) async def message_search( q: str, current_user: Annotated[UserModel, Depends(get_current_user)] -) -> SearchModel: +) -> SearchModel[MessageModel]: return await index.message_search(q, current_user.id, 25) diff --git a/src/conversation-api/models/search.py b/src/conversation-api/models/search.py index 729f1c9c..a0433bc9 100644 --- a/src/conversation-api/models/search.py +++ b/src/conversation-api/models/search.py @@ -1,11 +1,12 @@ from pydantic import BaseModel +from pydantic.generics import GenericModel from typing import List, TypeVar, Generic -T = TypeVar("T") +T = TypeVar("T", bound=BaseModel) -class SearchAnswerModel(BaseModel, Generic[T]): +class SearchAnswerModel(GenericModel, Generic[T]): data: T score: float @@ -15,7 +16,7 @@ class SearchStatsModel(BaseModel): total: int -class SearchModel(BaseModel, Generic[T]): +class SearchModel(GenericModel, Generic[T]): answers: List[SearchAnswerModel[T]] query: str stats: SearchStatsModel diff --git a/src/conversation-api/persistence/cosmos.py b/src/conversation-api/persistence/cosmos.py index cf298aca..b19958e3 100644 --- a/src/conversation-api/persistence/cosmos.py +++ b/src/conversation-api/persistence/cosmos.py @@ -92,10 +92,18 @@ async def conversation_list(self, user_id: UUID) -> List[StoredConversationModel query = ( f"SELECT * FROM c WHERE c.user_id = '{user_id}' ORDER BY c.created_at DESC" ) - items = conversation_client.query_items( + raws = conversation_client.query_items( query=query, enable_cross_partition_query=True ) - return [StoredConversationModel(**item) for item in items] + conversations = [] + for raw in raws: + if raw is None: + continue + try: + conversations.append(StoredConversationModel(**raw)) + except Exception: + logger.warn("Error parsing conversation", exc_info=True) + return conversations async def message_get( self, message_id: UUID, conversation_id: UUID @@ -134,10 +142,18 @@ async def message_set(self, message: StoredMessageModel) -> None: async def message_list(self, conversation_id: UUID) -> List[MessageModel]: query = f"SELECT * FROM c WHERE c.conversation_id = '{conversation_id}' ORDER BY c.created_at ASC" - items = message_client.query_items( + raws = message_client.query_items( query=query, enable_cross_partition_query=True ) - return [MessageModel(**item) for item in items] + items = [] + for raw in raws: + if raw is None: + continue + try: + items.append(MessageModel(**raw)) + except Exception: + logger.warn("Error parsing message", exc_info=True) + return items async def usage_set(self, usage: UsageModel) -> None: logger.debug(f'Usage set "{usage.id}"')