Skip to content

Commit

Permalink
Merge branch 'qjufa_chat_1414' into development
Browse files Browse the repository at this point in the history
  • Loading branch information
beastoin committed Dec 15, 2024
2 parents e83154b + 55602ba commit 7804353
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 20 deletions.
4 changes: 2 additions & 2 deletions backend/database/vector_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
from collections import defaultdict
from datetime import datetime, timezone
from datetime import datetime, timezone, timedelta
from typing import List

from pinecone import Pinecone
Expand Down Expand Up @@ -88,7 +88,7 @@ def query_vectors_by_metadata(
if dates_filter and len(dates_filter) == 2 and dates_filter[0] and dates_filter[1]:
print('dates_filter', dates_filter)
filter_data['$and'].append(
{'created_at': {'$gte': int(dates_filter[0].timestamp()), '$lte': int(dates_filter[1].timestamp())}}
{'created_at': {'$gte': int(dates_filter[0].timestamp()), '$lte': int((dates_filter[1]+timedelta(days=1)).timestamp())-1}}
)

print('query_vectors_by_metadata:', json.dumps(filter_data))
Expand Down
38 changes: 27 additions & 11 deletions backend/utils/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import re
from datetime import datetime
from datetime import datetime, timezone
from typing import List, Optional

import tiktoken
Expand Down Expand Up @@ -304,18 +304,16 @@ def retrieve_context_topics(messages: List[Message]) -> List[str]:
return topics


def retrieve_context_dates(messages: List[Message]) -> List[datetime]:
def retrieve_context_dates(messages: List[Message], tz: str) -> List[datetime]:
prompt = f'''
Based on the current conversation an AI and a User are having, for the AI to answer the latest user messages, it needs context outside the conversation.
Your task is to to find the dates range in which the current conversation needs context about, in order to answer the most recent user request.
For example, if the user request relates to "What did I do last week?", or "What did I learn yesterday", or "Who did I meet today?", the dates range should be provided.
For example, if the user request relates to "What did I do last week?", or "What did I learn yesterday", or "Who did I meet today?", the dates range should be provided.
Other type of dates, like historical events, or future events, should be ignored and an empty list should be returned.
For context, today is {datetime.now().isoformat()}.
Year: {datetime.now().year}, Month: {datetime.now().month}, Day: {datetime.now().day}
For context, today is {datetime.now(timezone.utc).strftime('%Y-%m-%d')} in UTC. {tz} is the user's timezone, convert it to UTC and respond in UTC.
Conversation:
{Message.get_messages_as_string(messages)}
Expand All @@ -324,6 +322,24 @@ def retrieve_context_dates(messages: List[Message]) -> List[datetime]:
response: DatesContext = with_parser.invoke(prompt)
return response.dates_range

def retrieve_context_dates_by_question(question: str, tz: str) -> List[datetime]:
prompt = f'''
Based on a question asked by the user to an AI, for the AI to answer the user question, it needs context outside the question.
Your task is to to find the dates range in which the question needs context about, in order to answer the most recent user question.
For example, if the user request relates to "What did I do last week?", or "What did I learn yesterday", or "Who did I meet today?", the dates range should be provided.
Other type of dates, like historical events, or future events, should be ignored and an empty list should be returned.
For context, today is {datetime.now(timezone.utc).strftime('%Y-%m-%d')} in UTC. {tz} is the user's timezone, convert it to UTC and respond in UTC.
Question:
{question}
'''.replace(' ', '').strip()
with_parser = llm_mini.with_structured_output(DatesContext)
response: DatesContext = with_parser.invoke(prompt)
return response.dates_range


class SummaryOutput(BaseModel):
summary: str = Field(description="The extracted content, maximum 500 words.")
Expand Down Expand Up @@ -713,7 +729,7 @@ def extract_question_from_conversation(messages: List[Message]) -> str:


def retrieve_metadata_fields_from_transcript(
uid: str, created_at: datetime, transcript_segment: List[dict]
uid: str, created_at: datetime, transcript_segment: List[dict], tz: str
) -> ExtractedInformation:
transcript = ''
for segment in transcript_segment:
Expand All @@ -728,7 +744,7 @@ def retrieve_metadata_fields_from_transcript(
Make sure as a first step, you infer and fix the raw transcript errors and then proceed to extract the information.
For context when extracting dates, today is {created_at.strftime('%Y-%m-%d')}.
For context when extracting dates, today is {created_at.astimezone(timezone.utc).strftime('%Y-%m-%d')} in UTC. {tz} is the user's timezone, convert it to UTC and respond in UTC.
If one says "today", it means the current day.
If one says "tomorrow", it means the next day after today.
If one says "yesterday", it means the day before today.
Expand Down
3 changes: 2 additions & 1 deletion backend/utils/memories/process_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def save_structured_vector(uid: str, memory: Memory, update_only: bool = False):
vector = generate_embedding(str(memory.structured)) if not update_only else None

segments = [t.dict() for t in memory.transcript_segments]
metadata = retrieve_metadata_fields_from_transcript(uid, memory.created_at, segments)
tz = notification_db.get_user_time_zone(uid)
metadata = retrieve_metadata_fields_from_transcript(uid, memory.created_at, segments, tz)
metadata['created_at'] = int(memory.created_at.timestamp())
if not update_only:
print('save_structured_vector creating vector')
Expand Down
19 changes: 13 additions & 6 deletions backend/utils/retrieval/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import database.memories as memories_db
from database.redis_db import get_filter_category_items
from database.vector_db import query_vectors_by_metadata
import database.notifications as notification_db
from models.chat import Message
from models.memory import Memory
from models.plugin import Plugin
Expand All @@ -20,6 +21,7 @@
requires_context,
answer_simple_message,
retrieve_context_dates,
retrieve_context_dates_by_question,
qa_rag,
retrieve_is_an_omi_question,
select_structured_filters,
Expand Down Expand Up @@ -48,6 +50,7 @@ class GraphState(TypedDict):
uid: str
messages: List[Message]
plugin_selected: Optional[Plugin]
tz: str

filters: Optional[StructuredFilters]
date_filters: Optional[DateRangeFilters]
Expand Down Expand Up @@ -119,11 +122,14 @@ def retrieve_topics_filters(state: GraphState):
def retrieve_date_filters(state: GraphState):
print('retrieve_date_filters')
# TODO: if this makes vector search fail further, query firestore instead
dates_range = retrieve_context_dates(state.get("messages", []))
if dates_range and len(dates_range) == 2:
print('retrieve_date_filters dates_range:', dates_range)
return {"date_filters": {"start": dates_range[0], "end": dates_range[1]}}
return {"date_filters": {}}
dates_range = retrieve_context_dates_by_question(state.get("parsed_question", ""), state.get("tz", "UTC"))
print('retrieve_date_filters dates_range:', dates_range)
if not dates_range or len(dates_range) == 0:
return {"date_filters": {}}
if len(dates_range) == 1:
return {"date_filters": {"start": dates_range[0], "end": dates_range[0]}}
# >=2
return {"date_filters": {"start": dates_range[0], "end": dates_range[1]}}


def query_vectors(state: GraphState):
Expand Down Expand Up @@ -200,8 +206,9 @@ def execute_graph_chat(
uid: str, messages: List[Message], plugin: Optional[Plugin] = None
) -> Tuple[str, bool, List[Memory]]:
print('execute_graph_chat plugin :', plugin)
tz = notification_db.get_user_time_zone(uid)
result = graph.invoke(
{"uid": uid, "messages": messages, "plugin_selected": plugin},
{"uid": uid, "tz": tz, "messages": messages, "plugin_selected": plugin},
{"configurable": {"thread_id": str(uuid.uuid4())}},
)
return result.get("answer"), result.get('ask_for_nps', False), result.get("memories_found", [])
Expand Down

0 comments on commit 7804353

Please sign in to comment.