Skip to content

Commit

Permalink
Update - modified event db manager and db engine
Browse files Browse the repository at this point in the history
  • Loading branch information
aybruhm committed Jan 6, 2024
1 parent 12328d8 commit 17190f2
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 35 deletions.
4 changes: 0 additions & 4 deletions agenta-backend/agenta_backend/models/db_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,5 @@ def remove_db(self) -> None:
client = MongoClient(self.db_url)
if self.mode == "default":
client.drop_database("agenta")
elif self.mode == "v2":
client.drop_database("agenta_v2")
elif self.mode == "test":
client.drop_database("agenta_test")
else:
client.drop_database(f"agenta_{self.mode}")
63 changes: 32 additions & 31 deletions agenta-backend/agenta_backend/services/event_db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ async def get_variant_traces(
"""

user = await db_manager.get_user(user_uid=kwargs["uid"])
query_expressions = (
TraceDB.user == user.id,
traces = await TraceDB.find(
TraceDB.user.id == user.id,
TraceDB.app_id == app_id,
TraceDB.variant_id == variant_id,
)
traces = await TraceDB.find(query_expressions).to_list()
fetch_links=True
).to_list()
return [trace_db_to_pydantic(trace) for trace in traces]


Expand All @@ -70,13 +70,13 @@ async def create_app_trace(payload: CreateTrace, **kwargs: dict) -> str:

# Ensure spans exists in the db
for span in payload.spans:
span_db = await SpanDB.find_one(SpanDB.id == ObjectId(span))
span_db = await SpanDB.find_one(SpanDB.id == ObjectId(span), fetch_links=True)
if span_db is None:
raise HTTPException(404, detail=f"Span {span} does not exist")

trace = TraceDB(**payload.dict(), user=user)
await trace.create()
return trace_db_to_pydantic(trace)["trace_id"]
trace_db = TraceDB(**payload.dict(), user=user)
await trace_db.create()
return str(trace_db.id)


async def get_trace_single(trace_id: str, **kwargs: dict) -> Trace:
Expand All @@ -90,10 +90,11 @@ async def get_trace_single(trace_id: str, **kwargs: dict) -> Trace:
"""

user = await db_manager.get_user(user_uid=kwargs["uid"])
query_expressions = (TraceDB.id == ObjectId(trace_id), TraceDB.user == user.id)

# Get trace
trace = await TraceDB.find_one(query_expressions)
trace = await TraceDB.find_one(
TraceDB.id == ObjectId(trace_id), TraceDB.user.id == user.id, fetch_links=True
)
return trace_db_to_pydantic(trace)


Expand All @@ -111,14 +112,15 @@ async def trace_status_update(
"""

user = await db_manager.get_user(user_uid=kwargs["uid"])
query_expressions = (TraceDB.id == ObjectId(trace_id), TraceDB.user == user.id)

# Get trace
trace = await TraceDB.find_one(query_expressions)
trace = await TraceDB.find_one(
TraceDB.id == ObjectId(trace_id), TraceDB.user.id == user.id
)

# Update and save trace
trace.status = payload.status
await trace.create()
await trace.save()
return True


Expand Down Expand Up @@ -148,10 +150,11 @@ async def get_trace_spans(trace_id: str, **kwargs: dict) -> List[Span]:
"""

user = await db_manager.get_user(user_uid=kwargs["uid"])
query_expressions = (TraceDB.id == ObjectId(trace_id), TraceDB.user == user.id)

# Get trace
trace = await TraceDB.find_one(query_expressions)
trace = await TraceDB.find_one(
TraceDB.id == ObjectId(trace_id), TraceDB.user.id == user.id, fetch_links=True
)

# Get trace spans
spans = spans_db_to_pydantic(trace.spans)
Expand Down Expand Up @@ -179,14 +182,14 @@ async def add_feedback_to_trace(
created_at=datetime.utcnow(),
)

trace = await TraceDB.find_one(TraceDB.id == ObjectId(trace_id))
trace = await TraceDB.find_one(TraceDB.id == ObjectId(trace_id), fetch_links=True)
if trace.feedbacks is None:
trace.feedbacks = [feedback]
else:
trace.feedbacks.append(feedback)

# Update trace
await trace.create()
await trace.save()
return feedback.uid


Expand All @@ -202,11 +205,10 @@ async def get_trace_feedbacks(trace_id: str, **kwargs: dict) -> List[Feedback]:

user = await db_manager.get_user(user_uid=kwargs["uid"])

# Build query expressions
query_expressions = (TraceDB.id == ObjectId(trace_id), TraceDB.user == user.id)

# Get feedbacks in trace
trace = await TraceDB.find_one(query_expressions)
trace = await TraceDB.find_one(
TraceDB.id == ObjectId(trace_id), TraceDB.user.id == user.id, fetch_links=True
)
feedbacks = [feedback_db_to_pydantic(feedback) for feedback in trace.feedbacks]
return feedbacks

Expand All @@ -226,11 +228,10 @@ async def get_feedback_detail(

user = await db_manager.get_user(user_uid=kwargs["uid"])

# Build query expressions
query_expressions = (TraceDB.id == ObjectId(trace_id), TraceDB.user == user.id)

# Get trace
trace = await TraceDB.find_one(query_expressions)
trace = await TraceDB.find_one(
TraceDB.id == ObjectId(trace_id), TraceDB.user.id == user.id, fetch_links=True
)

# Get feedback
feedback = [
Expand All @@ -257,22 +258,22 @@ async def update_trace_feedback(

user = await db_manager.get_user(user_uid=kwargs["uid"])

# Build query expressions
query_expressions = (TraceDB.id == ObjectId(trace_id), TraceDB.user == user.id)

# Get trace
trace = await TraceDB.find_one(query_expressions)
trace = await TraceDB.find_one(
TraceDB.id == ObjectId(trace_id), TraceDB.user.id == user.id, fetch_links=True
)

# update feedback
feedback_json = {}
for feedback in trace.feedbacks:
if feedback.uid == feedback_id:
feedback.update(payload.dict())
for key, value in payload.dict(exclude_none=True).items():
setattr(feedback, key, value)
feedback_json = feedback.dict()
break

# Save feedback in trace and return a copy
await trace.create()
await trace.save()

# Replace key and transform into a pydantic representation
feedback_json["feedback_id"] = feedback_json.pop("uid")
Expand Down

0 comments on commit 17190f2

Please sign in to comment.