Skip to content

Commit

Permalink
make events query more descriptive
Browse files Browse the repository at this point in the history
  • Loading branch information
stufisher committed Aug 9, 2022
1 parent 8763fe4 commit 66e4d78
Showing 1 changed file with 78 additions and 48 deletions.
126 changes: 78 additions & 48 deletions pyispyb/core/modules/events.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, Any
from dataclasses import dataclass, field
from typing import Any, List, Optional
import os

import sqlalchemy
Expand All @@ -16,6 +17,41 @@
from ..schemas import events as schema


@dataclass
class EntityType:
    """Describes how one event entity type plugs into the events query.

    Each instance bundles the mapped entity, the column it uses to join to
    ``BLSample``, its primary-key attribute name, and any related entities
    that should be eager-loaded with it.

    All annotations are written as forward references (strings) for
    consistency with the already-quoted ``sampleId`` field; this also avoids
    evaluating the SQLAlchemy typing machinery at class-definition time.
    """

    # The mapped entity, e.g. `DataCollection` or `EnergyScan`
    entity: "sqlalchemy.orm.decl_api.DeclarativeMeta"
    # The column joining the entity to `BLSample`,
    # e.g. `DataCollection.blSampleId`
    sampleId: "sqlalchemy.Column[Any]"
    # The entity's primary-key attribute name, e.g. `dataCollectionId`
    key: str
    # Related entities to eager-load alongside, e.g. `DataCollectionGroup`.
    # Declared Optional for caller compatibility, but the default is always
    # an empty list (never None).
    joined: "Optional[List[sqlalchemy.orm.decl_api.DeclarativeMeta]]" = field(
        default_factory=list
    )


# Registry of every event entity type, keyed by the short discriminator
# string used as the `type` column of the unioned events query.
ENTITY_TYPES: dict[str, EntityType] = {
    # Data collections reach `BLSample` via their `DataCollectionGroup`,
    # which is therefore also eager-loaded.
    "dc": EntityType(
        entity=models.DataCollection,
        sampleId=models.DataCollectionGroup.blSampleId,
        key="dataCollectionId",
        joined=[models.DataCollection.DataCollectionGroup],
    ),
    "robot": EntityType(
        entity=models.RobotAction,
        sampleId=models.RobotAction.blsampleId,
        key="robotActionId",
    ),
    "xrf": EntityType(
        entity=models.XFEFluorescenceSpectrum,
        sampleId=models.XFEFluorescenceSpectrum.blSampleId,
        key="xfeFluorescenceSpectrumId",
    ),
    "es": EntityType(
        entity=models.EnergyScan,
        sampleId=models.EnergyScan.blSampleId,
        key="energyScanId",
    ),
}


def with_sample(
query: "sqlalchemy.orm.Query[Any]",
column: "sqlalchemy.Column[Any]",
Expand Down Expand Up @@ -69,12 +105,12 @@ def get_events(
if dataCollectionGroupId is None:
duration = sqlalchemy.func.sum(duration)
# Return the first dataCollectionId in a group
_dataCollectionId = sqlalchemy.func.min(models.DataCollection.dataCollectionId) # type: ignore
startTime = sqlalchemy.func.min(models.DataCollection.startTime) # type: ignore
endTime = sqlalchemy.func.max(models.DataCollection.endTime) # type: ignore
_dataCollectionId = sqlalchemy.func.min(models.DataCollection.dataCollectionId)
startTime = sqlalchemy.func.min(models.DataCollection.startTime)
endTime = sqlalchemy.func.max(models.DataCollection.endTime)
dataCollectionCount = sqlalchemy.func.count(
sqlalchemy.func.distinct(models.DataCollection.dataCollectionId)
) # type: ignore
)

queries["dc"] = (
db.session.query(
Expand Down Expand Up @@ -150,14 +186,10 @@ def get_events(
)

# Join sample information
_mapper = {
"dc": models.DataCollectionGroup.blSampleId,
"robot": models.RobotAction.blsampleId,
"xrf": models.XFEFluorescenceSpectrum.blSampleId,
"es": models.EnergyScan.blSampleId,
}
for key, _query in queries.items():
queries[key] = with_sample(_query, _mapper[key], blSampleId, proteinId)
queries[key] = with_sample(
_query, ENTITY_TYPES[key].sampleId, blSampleId, proteinId
)

# Apply permissions
if beamlineGroups:
Expand Down Expand Up @@ -207,6 +239,7 @@ def get_events(
models.DataCollectionGroup.dataCollectionGroupId
)

# Now union the four queries
query: sqlalchemy.orm.Query[Any] = queries["dc"].union_all(
queries["robot"], queries["xrf"], queries["es"]
)
Expand All @@ -215,50 +248,47 @@ def get_events(
query = query.order_by(sqlalchemy.desc("startTime"))
query = page(query, skip=skip, limit=limit)

# Results contains an index of type / id
results = query.all()
results = [r._asdict() for r in results]

ids: dict[str, list[int]] = {}
types: dict[str, list[Any]] = {
"dc": [
models.DataCollection,
"dataCollectionId",
models.DataCollection.DataCollectionGroup,
],
"robot": [models.RobotAction, "robotActionId"],
"xrf": [models.XFEFluorescenceSpectrum, "xfeFluorescenceSpectrumId"],
"es": [models.EnergyScan, "energyScanId"],
}
# Build a list of ids to load based on type, i.e. a list of `dataCollectionId`s
entity_ids: dict[str, list[int]] = {}
for result in results:
for name in types.keys():
for name in ENTITY_TYPES.keys():
if result["type"] == name:
if name not in ids:
ids[name] = []
ids[name].append(result["id"])

type_map = {}
for name, ty in types.items():
if name in ids:
column = getattr(ty[0], ty[1])
if len(ty) > 2:
items = (
db.session.query(ty[0])
.join(ty[2])
.options(contains_eager(ty[2]))
.filter(column.in_(ids[name]))
.all()
)
else:
items = db.session.query(ty[0]).filter(column.in_(ids[name])).all()
type_map[name] = {getattr(item, ty[1]): item for item in items}
if name not in entity_ids:
entity_ids[name] = []
entity_ids[name].append(result["id"])

# Now load the related entities, i.e. load the `DataCollection` or `EnergyScan`
entity_type_map = {}
for name, entity_type in ENTITY_TYPES.items():
if name in entity_ids:
column = getattr(entity_type.entity, entity_type.key)
query = db.session.query(entity_type.entity).filter(
column.in_(entity_ids[name])
)

# If there are joined entities load those too
if entity_type.joined:
for joined_entity in entity_type.joined:
query = query.outerjoin(joined_entity).options(
contains_eager(joined_entity)
)

entity_type_map[name] = {
getattr(entity, entity_type.key): entity for entity in query.all()
}

# Merge the loaded entities back into the index's `Item`
for result in results:
for name, ty in types.items():
if result["type"] == name:
if name in type_map:
result["Item"] = type_map[name][result["id"]]
for entity_type_name in ENTITY_TYPES.keys():
if result["type"] == entity_type_name:
if entity_type_name in entity_type_map:
result["Item"] = entity_type_map[entity_type_name][result["id"]]

if name == "dc":
if entity_type_name == "dc":
_check_snapshots(result["Item"])

return Paged(total=total, results=results, skip=skip, limit=limit)
Expand Down

0 comments on commit 66e4d78

Please sign in to comment.