Commit
Convert Query object to Select object in the _get_nested_collection_attributes method
jdavcs committed Nov 14, 2023
1 parent 75c49bf commit 9d3fab8
Showing 2 changed files with 68 additions and 38 deletions.
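
For readers following the SQLAlchemy 2.0 migration: the commit replaces the legacy Query API (session.query(...).filter(...)) with an explicit Select object (select(...).where(...)) that is executed in a separate step. A minimal, self-contained sketch of the pattern (the User model here is invented for illustration; it is not Galaxy code), assuming SQLAlchemy 2.0:

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(User(name="alice"))
        session.commit()

        # Legacy 1.x style: the Query object builds and executes through one API.
        legacy = session.query(User.name).filter(User.id == 1).all()

        # 2.0 style: build a Select, then hand it to the session explicitly.
        stmt = select(User.name).where(User.id == 1)
        modern = session.execute(stmt).all()

        assert legacy == modern  # both yield Row tuples: [("alice",)]

Separating construction from execution is what makes the new return_statement_only flag below possible: a caller can take the bare Select and compose it into a larger statement before anything hits the database.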
79 changes: 51 additions & 28 deletions lib/galaxy/model/__init__.py
@@ -6121,6 +6121,7 @@ def _get_nested_collection_attributes(
             ]
         ] = None,
         inner_filter: Optional[InnerCollectionFilter] = None,
+        return_statement_only: bool = False,
     ):
         collection_attributes = collection_attributes or ()
         element_attributes = element_attributes or ()
@@ -6141,32 +6142,34 @@ def attribute_columns(column_collection, attributes, nesting_level=None):
             label_fragment = f"_{nesting_level}" if nesting_level is not None else ""
             return [getattr(column_collection, a).label(f"{a}{label_fragment}") for a in attributes]
 
-        q = (
-            db_session.query(
+        stmt = (
+            select(
                 *attribute_columns(dce.c, element_attributes, nesting_level),
                 *attribute_columns(dc.c, collection_attributes, nesting_level),
             )
             .select_from(dce, dc)
             .join(dce, dce.c.dataset_collection_id == dc.c.id)
-            .filter(dc.c.id == dataset_collection.id)
+            .where(dc.c.id == dataset_collection.id)
         )
 
         while ":" in depth_collection_type:
             nesting_level += 1
             inner_dc = alias(DatasetCollection)
             inner_dce = alias(DatasetCollectionElement)
             order_by_columns.append(inner_dce.c.element_index)
-            q = q.join(
+            stmt = stmt.join(
                 inner_dc, and_(inner_dc.c.id == dce.c.child_collection_id, dce.c.dataset_collection_id == dc.c.id)
             ).outerjoin(inner_dce, inner_dce.c.dataset_collection_id == inner_dc.c.id)
-            q = q.add_columns(
+            stmt = stmt.add_columns(
                 *attribute_columns(inner_dce.c, element_attributes, nesting_level),
                 *attribute_columns(inner_dc.c, collection_attributes, nesting_level),
             )
             dce = inner_dce
             dc = inner_dc
             depth_collection_type = depth_collection_type.split(":", 1)[1]
 
         if inner_filter:
-            q = q.filter(inner_filter.produce_filter(dc.c))
+            stmt = stmt.where(inner_filter.produce_filter(dc.c))
 
         if (
             hda_attributes
@@ -6175,27 +6178,39 @@ def attribute_columns(column_collection, attributes, nesting_level=None):
             or return_entities
             and not return_entities == (DatasetCollectionElement,)
         ):
-            q = q.join(HistoryDatasetAssociation).join(Dataset)
+            stmt = stmt.join(HistoryDatasetAssociation).join(Dataset)
 
         if dataset_permission_attributes:
-            q = q.join(DatasetPermissions)
-        q = (
-            q.add_columns(*attribute_columns(HistoryDatasetAssociation, hda_attributes))
-            .add_columns(*attribute_columns(Dataset, dataset_attributes))
-            .add_columns(*attribute_columns(DatasetPermissions, dataset_permission_attributes))
-        )
+            stmt = stmt.join(DatasetPermissions)
+
+        stmt = stmt.add_columns(*attribute_columns(HistoryDatasetAssociation, hda_attributes))
+        stmt = stmt.add_columns(*attribute_columns(Dataset, dataset_attributes))
+        stmt = stmt.add_columns(*attribute_columns(DatasetPermissions, dataset_permission_attributes))
+
         for entity in return_entities:
-            q = q.add_entity(entity)
+            stmt = stmt.add_columns(entity)
             if entity == DatasetCollectionElement:
-                q = q.filter(entity.id == dce.c.id)
-        return q.distinct().order_by(*order_by_columns)
+                stmt = stmt.where(entity.id == dce.c.id)
+
+        # Since we apply DISTINCT, all columns from the ORDER BY clause must appear in the SELECT clause.
+        # Note: when using session.scalars(stmt), these added columns are NOT returned (which is the desired behavior).
+        for col in order_by_columns:
+            stmt = stmt.add_columns(col)
+
+        stmt = stmt.distinct().order_by(*order_by_columns)
+
+        if return_statement_only:
+            return stmt
+        else:
+            return db_session.scalars(stmt)
 
     @property
     def dataset_states_and_extensions_summary(self):
        if not hasattr(self, "_dataset_states_and_extensions_summary"):
-            q = self._get_nested_collection_attributes(hda_attributes=("extension",), dataset_attributes=("state",))
+            nca = self._get_nested_collection_attributes(hda_attributes=("extension",), dataset_attributes=("state",))
             extensions = set()
             states = set()
-            for extension, state in q:
+            for extension, state in nca:
                 states.add(state)
                 extensions.add(extension)
 
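A note on the ORDER BY loop added above: PostgreSQL rejects SELECT DISTINCT statements whose ORDER BY expressions are missing from the select list, and the legacy Query.distinct() used to append such columns implicitly (behavior deprecated in SQLAlchemy 1.4), so a converted Select must add them by hand. A small sketch with an invented table:

    from sqlalchemy import Column, Integer, MetaData, String, Table, select

    metadata = MetaData()
    items = Table("items", metadata, Column("name", String), Column("rank", Integer))

    # Rejected by PostgreSQL: "rank" is ordered on but never selected.
    bad = select(items.c.name).distinct().order_by(items.c.rank)

    # Accepted: every ORDER BY column also appears in the select list.
    good = select(items.c.name).add_columns(items.c.rank).distinct().order_by(items.c.rank)
    print(good)
    # SELECT DISTINCT items.name, items.rank FROM items ORDER BY items.rank
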
@@ -6209,8 +6224,8 @@ def has_deferred_data(self):
         has_deferred_data = False
         if object_session(self):
             # TODO: Optimize by just querying without returning the states...
-            q = self._get_nested_collection_attributes(dataset_attributes=("state",))
-            for (state,) in q:
+            nca = self._get_nested_collection_attributes(dataset_attributes=("state",))
+            for (state,) in nca:
                 if state == Dataset.states.DEFERRED:
                     has_deferred_data = True
                     break
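
The "for (state,) in nca" idiom above works because each result row is a tuple-like Row, even when only one column is selected. A tiny sketch with an invented table:

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")
    with Session(engine) as session:
        session.execute(text("CREATE TABLE d (state TEXT)"))
        session.execute(text("INSERT INTO d VALUES ('ok'), ('deferred')"))
        result = session.execute(text("SELECT state FROM d"))
        for (state,) in result:  # each one-column row unpacks like a 1-tuple
            print(state)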
@@ -6231,18 +6246,21 @@ def populated_optimized(self):
         if ":" not in self.collection_type:
             _populated_optimized = self.populated_state == DatasetCollection.populated_states.OK
         else:
-            q = self._get_nested_collection_attributes(
+            stmt = self._get_nested_collection_attributes(
                 collection_attributes=("populated_state",),
                 inner_filter=InnerCollectionFilter(
                     "populated_state", operator.__ne__, DatasetCollection.populated_states.OK
                 ),
+                return_statement_only=True,
             )
-            _populated_optimized = q.session.query(~exists(q.subquery())).scalar()
+            stmt = select(~exists(stmt))
+            session = object_session(self)
+            _populated_optimized = session.scalar(stmt)
 
         self._populated_optimized = _populated_optimized
 
         return self._populated_optimized
 
     @property
     def populated(self):
         top_level_populated = self.populated_state == DatasetCollection.populated_states.OK
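
The rewritten populated_optimized composes the returned Select into an anti-membership check: wrap the inner statement in NOT EXISTS and let the database answer "is there any non-ok element?" in a single round trip. A sketch of the same composition with an invented table, assuming SQLAlchemy 2.0:

    from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, exists, select
    from sqlalchemy.orm import Session

    metadata = MetaData()
    jobs = Table("jobs", metadata, Column("id", Integer, primary_key=True), Column("state", String))

    engine = create_engine("sqlite://")
    metadata.create_all(engine)

    with Session(engine) as session:
        inner = select(jobs.c.id).where(jobs.c.state != "ok")  # rows violating the invariant
        all_ok = session.scalar(select(~exists(inner)))  # True iff no violating row exists
        print(all_ok)  # True here, since the table is empty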
@@ -6253,9 +6276,9 @@ def populated(self):
     @property
     def dataset_action_tuples(self):
         if not hasattr(self, "_dataset_action_tuples"):
-            q = self._get_nested_collection_attributes(dataset_permission_attributes=("action", "role_id"))
+            nca = self._get_nested_collection_attributes(dataset_permission_attributes=("action", "role_id"))
             _dataset_action_tuples = []
-            for _dataset_action_tuple in q:
+            for _dataset_action_tuple in nca:
                 if _dataset_action_tuple[0] is None:
                     continue
                 _dataset_action_tuples.append(_dataset_action_tuple)
@@ -6266,7 +6289,7 @@ def dataset_action_tuples(self):
 
     @property
     def element_identifiers_extensions_and_paths(self):
-        q = self._get_nested_collection_attributes(
+        nca = self._get_nested_collection_attributes(
             element_attributes=("element_identifier",), hda_attributes=("extension",), return_entities=(Dataset,)
         )
-        return [(row[:-2], row.extension, row.Dataset.get_file_name()) for row in q]
+        return [(row[:-2], row.extension, row.Dataset.get_file_name()) for row in nca]
@@ -6277,13 +6300,13 @@ def element_identifiers_extensions_paths_and_metadata_files(
     ) -> List[List[Any]]:
         results = []
         if object_session(self):
-            q = self._get_nested_collection_attributes(
+            nca = self._get_nested_collection_attributes(
                 element_attributes=("element_identifier",),
                 hda_attributes=("extension",),
                 return_entities=(HistoryDatasetAssociation, Dataset),
             )
             # element_identifiers, extension, path
-            for row in q:
+            for row in nca:
                 result = [row[:-3], row.extension, row.Dataset.get_file_name()]
                 hda = row.HistoryDatasetAssociation
                 result.append(hda.get_metadata_file_paths_and_extensions())
@@ -6430,7 +6453,7 @@ def copy(
 
     def replace_failed_elements(self, replacements):
         hda_id_to_element = dict(
-            self._get_nested_collection_attributes(return_entities=[DatasetCollectionElement], hda_attributes=["id"])
+            self._get_nested_collection_attributes(return_entities=[DatasetCollectionElement], hda_attributes=["id"]).all()
         )
         for failed, replacement in replacements.items():
             element = hda_id_to_element.get(failed.id)
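
replace_failed_elements feeds the two-column result (an hda id plus an element entity) straight into dict(): Row objects behave as (key, value) pairs, so materializing them with .all() and passing them to dict() builds the lookup table in one step. A sketch with invented data:

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")
    with Session(engine) as session:
        session.execute(text("CREATE TABLE el (hda_id INTEGER, element TEXT)"))
        session.execute(text("INSERT INTO el VALUES (10, 'dce-1'), (11, 'dce-2')"))
        rows = session.execute(text("SELECT hda_id, element FROM el")).all()
        hda_id_to_element = dict(rows)  # each Row is a 2-item sequence
        print(hda_id_to_element)  # {10: 'dce-1', 11: 'dce-2'}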
27 changes: 17 additions & 10 deletions test/unit/data/test_galaxy_mapping.py
@@ -392,18 +392,23 @@ def test_nested_collection_attributes(self):
         )
         self.model.session.add_all([d1, d2, c1, dce1, dce2, c2, dce3, c3, c4, dce4])
         self.model.session.flush()
-        q = c2._get_nested_collection_attributes(
+
+        rows = c2._get_nested_collection_attributes(
             element_attributes=("element_identifier",), hda_attributes=("extension",), dataset_attributes=("state",)
-        )
-        assert [(r._fields) for r in q] == [
+        ).all()
+
+        assert [(r._fields) for r in rows] == [
             ("element_identifier_0", "element_identifier_1", "extension", "state"),
             ("element_identifier_0", "element_identifier_1", "extension", "state"),
         ]
-        assert q.all() == [("inner_list", "forward", "bam", "new"), ("inner_list", "reverse", "txt", "new")]
-        q = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation,))
-        assert q.all() == [d1, d2]
-        q = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation, model.Dataset))
-        assert q.all() == [(d1, d1.dataset), (d2, d2.dataset)]
+        assert rows == [("inner_list", "forward", "bam", "new"), ("inner_list", "reverse", "txt", "new")]
+
+        rows = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation,)).all()
+        assert rows == [d1, d2]
+
+        rows = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation, model.Dataset)).all()
+        assert rows == [(d1, d1.dataset), (d2, d2.dataset)]
+
         # Assert properties that use _get_nested_collection_attributes return correct content
         assert c2.dataset_instances == [d1, d2]
         assert c2.dataset_elements == [dce1, dce2]
@@ -422,8 +427,10 @@ def test_nested_collection_attributes(self):
         assert c3.dataset_instances == []
         assert c3.dataset_elements == []
         assert c3.dataset_states_and_extensions_summary == (set(), set())
-        q = c4._get_nested_collection_attributes(element_attributes=("element_identifier",))
-        assert q.all() == [("outer_list", "inner_list", "forward"), ("outer_list", "inner_list", "reverse")]
+
+        rows = c4._get_nested_collection_attributes(element_attributes=("element_identifier",)).all()
+        assert rows == [("outer_list", "inner_list", "forward"), ("outer_list", "inner_list", "reverse")]
+
         assert c4.dataset_elements == [dce1, dce2]
         assert c4.element_identifiers_extensions_and_paths == [
             (("outer_list", "inner_list", "forward"), "bam", "mock_dataset_14.dat"),
