diff --git a/pyispyb/core/modules/events.py b/pyispyb/core/modules/events.py index 214bbbe0..e163cf97 100644 --- a/pyispyb/core/modules/events.py +++ b/pyispyb/core/modules/events.py @@ -1,4 +1,5 @@ -from typing import Optional, Any +from dataclasses import dataclass, field +from typing import Any, List, Optional import os import sqlalchemy @@ -16,6 +17,41 @@ from ..schemas import events as schema +@dataclass +class EntityType: + # The entity `DataCollection` or `EnergyScan` + entity: sqlalchemy.orm.decl_api.DeclarativeMeta + # How the entity joins to `BLSample` i.e. `DataCollection.blSampleId` + sampleId: "sqlalchemy.Column[Any]" + # Its primary key `dataCollectionId` + key: str + # Any joined entities i.e. `DataCollectionGroup`` + joined: Optional[List[sqlalchemy.orm.decl_api.DeclarativeMeta]] = field( + default_factory=list + ) + + +ENTITY_TYPES: dict[str, EntityType] = { + "dc": EntityType( + models.DataCollection, + models.DataCollectionGroup.blSampleId, + "dataCollectionId", + [ + models.DataCollection.DataCollectionGroup, + ], + ), + "robot": EntityType( + models.RobotAction, models.RobotAction.blsampleId, "robotActionId" + ), + "xrf": EntityType( + models.XFEFluorescenceSpectrum, + models.XFEFluorescenceSpectrum.blSampleId, + "xfeFluorescenceSpectrumId", + ), + "es": EntityType(models.EnergyScan, models.EnergyScan.blSampleId, "energyScanId"), +} + + def with_sample( query: "sqlalchemy.orm.Query[Any]", column: "sqlalchemy.Column[Any]", @@ -69,12 +105,12 @@ def get_events( if dataCollectionGroupId is None: duration = sqlalchemy.func.sum(duration) # Return the first dataCollectionId in a group - _dataCollectionId = sqlalchemy.func.min(models.DataCollection.dataCollectionId) # type: ignore - startTime = sqlalchemy.func.min(models.DataCollection.startTime) # type: ignore - endTime = sqlalchemy.func.max(models.DataCollection.endTime) # type: ignore + _dataCollectionId = sqlalchemy.func.min(models.DataCollection.dataCollectionId) + startTime = sqlalchemy.func.min(models.DataCollection.startTime) + endTime = sqlalchemy.func.max(models.DataCollection.endTime) dataCollectionCount = sqlalchemy.func.count( sqlalchemy.func.distinct(models.DataCollection.dataCollectionId) - ) # type: ignore + ) queries["dc"] = ( db.session.query( @@ -150,14 +186,10 @@ def get_events( ) # Join sample information - _mapper = { - "dc": models.DataCollectionGroup.blSampleId, - "robot": models.RobotAction.blsampleId, - "xrf": models.XFEFluorescenceSpectrum.blSampleId, - "es": models.EnergyScan.blSampleId, - } for key, _query in queries.items(): - queries[key] = with_sample(_query, _mapper[key], blSampleId, proteinId) + queries[key] = with_sample( + _query, ENTITY_TYPES[key].sampleId, blSampleId, proteinId + ) # Apply permissions if beamlineGroups: @@ -207,6 +239,7 @@ def get_events( models.DataCollectionGroup.dataCollectionGroupId ) + # Now union the four queries query: sqlalchemy.orm.Query[Any] = queries["dc"].union_all( queries["robot"], queries["xrf"], queries["es"] ) @@ -215,50 +248,47 @@ def get_events( query = query.order_by(sqlalchemy.desc("startTime")) query = page(query, skip=skip, limit=limit) + # Results contains an index of type / id results = query.all() results = [r._asdict() for r in results] - ids: dict[str, list[int]] = {} - types: dict[str, list[Any]] = { - "dc": [ - models.DataCollection, - "dataCollectionId", - models.DataCollection.DataCollectionGroup, - ], - "robot": [models.RobotAction, "robotActionId"], - "xrf": [models.XFEFluorescenceSpectrum, "xfeFluorescenceSpectrumId"], - "es": [models.EnergyScan, "energyScanId"], - } + # Build a list of ids to load based on type, i.e. a list of `dataCollectionId`s + entity_ids: dict[str, list[int]] = {} for result in results: - for name in types.keys(): + for name in ENTITY_TYPES.keys(): if result["type"] == name: - if name not in ids: - ids[name] = [] - ids[name].append(result["id"]) - - type_map = {} - for name, ty in types.items(): - if name in ids: - column = getattr(ty[0], ty[1]) - if len(ty) > 2: - items = ( - db.session.query(ty[0]) - .join(ty[2]) - .options(contains_eager(ty[2])) - .filter(column.in_(ids[name])) - .all() - ) - else: - items = db.session.query(ty[0]).filter(column.in_(ids[name])).all() - type_map[name] = {getattr(item, ty[1]): item for item in items} + if name not in entity_ids: + entity_ids[name] = [] + entity_ids[name].append(result["id"]) + + # Now load the related entities, i.e. load the `DataCollection` or `EnergyScan` + entity_type_map = {} + for name, entity_type in ENTITY_TYPES.items(): + if name in entity_ids: + column = getattr(entity_type.entity, entity_type.key) + query = db.session.query(entity_type.entity).filter( + column.in_(entity_ids[name]) + ) + + # If there are joined entities load those too + if entity_type.joined: + for joined_entity in entity_type.joined: + query = query.outerjoin(joined_entity).options( + contains_eager(joined_entity) + ) + entity_type_map[name] = { + getattr(entity, entity_type.key): entity for entity in query.all() + } + + # Merge the loaded entities back into the index's `Item` for result in results: - for name, ty in types.items(): - if result["type"] == name: - if name in type_map: - result["Item"] = type_map[name][result["id"]] + for entity_type_name in ENTITY_TYPES.keys(): + if result["type"] == entity_type_name: + if entity_type_name in entity_type_map: + result["Item"] = entity_type_map[entity_type_name][result["id"]] - if name == "dc": + if entity_type_name == "dc": _check_snapshots(result["Item"]) return Paged(total=total, results=results, skip=skip, limit=limit)