diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py
index c437cb1ee65a..f66df550407d 100644
--- a/lib/galaxy/model/__init__.py
+++ b/lib/galaxy/model/__init__.py
@@ -6121,6 +6121,7 @@ def _get_nested_collection_attributes(
             ]
         ] = None,
         inner_filter: Optional[InnerCollectionFilter] = None,
+        return_statement_only: bool = False,
     ):
         collection_attributes = collection_attributes or ()
         element_attributes = element_attributes or ()
@@ -6141,32 +6142,34 @@ def attribute_columns(column_collection, attributes, nesting_level=None):
             label_fragment = f"_{nesting_level}" if nesting_level is not None else ""
             return [getattr(column_collection, a).label(f"{a}{label_fragment}") for a in attributes]
 
-        q = (
-            db_session.query(
+        stmt = (
+            select(
                 *attribute_columns(dce.c, element_attributes, nesting_level),
                 *attribute_columns(dc.c, collection_attributes, nesting_level),
             )
             .select_from(dce, dc)
             .join(dce, dce.c.dataset_collection_id == dc.c.id)
-            .filter(dc.c.id == dataset_collection.id)
+            .where(dc.c.id == dataset_collection.id)
         )
+
         while ":" in depth_collection_type:
             nesting_level += 1
             inner_dc = alias(DatasetCollection)
             inner_dce = alias(DatasetCollectionElement)
             order_by_columns.append(inner_dce.c.element_index)
-            q = q.join(
+            stmt = stmt.join(
                 inner_dc, and_(inner_dc.c.id == dce.c.child_collection_id, dce.c.dataset_collection_id == dc.c.id)
             ).outerjoin(inner_dce, inner_dce.c.dataset_collection_id == inner_dc.c.id)
-            q = q.add_columns(
+            stmt = stmt.add_columns(
                 *attribute_columns(inner_dce.c, element_attributes, nesting_level),
                 *attribute_columns(inner_dc.c, collection_attributes, nesting_level),
             )
             dce = inner_dce
             dc = inner_dc
             depth_collection_type = depth_collection_type.split(":", 1)[1]
+
         if inner_filter:
-            q = q.filter(inner_filter.produce_filter(dc.c))
+            stmt = stmt.where(inner_filter.produce_filter(dc.c))
         if (
             hda_attributes
@@ -6175,27 +6178,39 @@ def attribute_columns(column_collection, attributes, nesting_level=None):
             or return_entities
             and not return_entities == (DatasetCollectionElement,)
         ):
-            q = q.join(HistoryDatasetAssociation).join(Dataset)
+            stmt = stmt.join(HistoryDatasetAssociation).join(Dataset)
+
         if dataset_permission_attributes:
-            q = q.join(DatasetPermissions)
-        q = (
-            q.add_columns(*attribute_columns(HistoryDatasetAssociation, hda_attributes))
-            .add_columns(*attribute_columns(Dataset, dataset_attributes))
-            .add_columns(*attribute_columns(DatasetPermissions, dataset_permission_attributes))
-        )
+            stmt = stmt.join(DatasetPermissions)
+
+        stmt = stmt.add_columns(*attribute_columns(HistoryDatasetAssociation, hda_attributes))
+        stmt = stmt.add_columns(*attribute_columns(Dataset, dataset_attributes))
+        stmt = stmt.add_columns(*attribute_columns(DatasetPermissions, dataset_permission_attributes))
+
         for entity in return_entities:
-            q = q.add_entity(entity)
+            stmt = stmt.add_columns(entity)
             if entity == DatasetCollectionElement:
-                q = q.filter(entity.id == dce.c.id)
-        return q.distinct().order_by(*order_by_columns)
+                stmt = stmt.where(entity.id == dce.c.id)
+
+        # Since we apply DISTINCT, all ORDER BY columns must also appear in the SELECT clause.
+        # Note: session.scalars(stmt) does not return these appended columns, which is the desired behavior.
+        for col in order_by_columns:
+            stmt = stmt.add_columns(col)
+
+        stmt = stmt.distinct().order_by(*order_by_columns)
+
+        if return_statement_only:
+            return stmt
+        else:
+            return db_session.scalars(stmt)
 
     @property
     def dataset_states_and_extensions_summary(self):
         if not hasattr(self, "_dataset_states_and_extensions_summary"):
-            q = self._get_nested_collection_attributes(hda_attributes=("extension",), dataset_attributes=("state",))
+            nca = self._get_nested_collection_attributes(hda_attributes=("extension",), dataset_attributes=("state",))
             extensions = set()
             states = set()
-            for extension, state in q:
+            for extension, state in nca:
                 states.add(state)
                 extensions.add(extension)
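Not part of the diff: a minimal, self-contained sketch of the 1.x-to-2.0 pattern the hunks above follow, where `Query.filter()` becomes `select().where()` and execution becomes explicit. The `Toy` model, the in-memory SQLite engine, and the session setup are illustrative assumptions, not Galaxy code.

```python
# Hedged sketch (SQLAlchemy 1.4/2.0): contrasts the legacy Query API with the
# select() construct this diff migrates to. All names below are made up.
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Toy(Base):
    __tablename__ = "toy"
    id = Column(Integer, primary_key=True)
    state = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Toy(state="ok"), Toy(state="new")])
    session.commit()

    # Legacy 1.x style: Query built from the session, filtered with .filter(),
    # executed implicitly when results are fetched.
    legacy = session.query(Toy.id, Toy.state).filter(Toy.state == "ok").all()

    # 2.0 style: a Select built with .where(), executed explicitly.
    stmt = select(Toy.id, Toy.state).where(Toy.state == "ok")
    rows = session.execute(stmt).all()  # list of Row objects

    assert [tuple(r) for r in legacy] == [tuple(r) for r in rows]
```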
@@ -6209,8 +6224,8 @@ def has_deferred_data(self):
         has_deferred_data = False
         if object_session(self):
             # TODO: Optimize by just querying without returning the states...
-            q = self._get_nested_collection_attributes(dataset_attributes=("state",))
-            for (state,) in q:
+            nca = self._get_nested_collection_attributes(dataset_attributes=("state",))
+            for (state,) in nca:
                 if state == Dataset.states.DEFERRED:
                     has_deferred_data = True
                     break
@@ -6231,18 +6246,21 @@ def populated_optimized(self):
             if ":" not in self.collection_type:
                 _populated_optimized = self.populated_state == DatasetCollection.populated_states.OK
             else:
-                q = self._get_nested_collection_attributes(
+                stmt = self._get_nested_collection_attributes(
                     collection_attributes=("populated_state",),
                     inner_filter=InnerCollectionFilter(
                         "populated_state", operator.__ne__, DatasetCollection.populated_states.OK
                     ),
+                    return_statement_only=True,
                 )
-                _populated_optimized = q.session.query(~exists(q.subquery())).scalar()
+                stmt = select(~exists(stmt))
+                session = object_session(self)
+                _populated_optimized = session.scalar(stmt)
 
             self._populated_optimized = _populated_optimized
 
         return self._populated_optimized
 
     @property
     def populated(self):
         top_level_populated = self.populated_state == DatasetCollection.populated_states.OK
@@ -6253,9 +6271,9 @@ def populated(self):
     @property
     def dataset_action_tuples(self):
         if not hasattr(self, "_dataset_action_tuples"):
-            q = self._get_nested_collection_attributes(dataset_permission_attributes=("action", "role_id"))
+            nca = self._get_nested_collection_attributes(dataset_permission_attributes=("action", "role_id"))
             _dataset_action_tuples = []
-            for _dataset_action_tuple in q:
+            for _dataset_action_tuple in nca:
                 if _dataset_action_tuple[0] is None:
                     continue
                 _dataset_action_tuples.append(_dataset_action_tuple)
@@ -6266,7 +6284,7 @@ def dataset_action_tuples(self):
 
     @property
     def element_identifiers_extensions_and_paths(self):
-        q = self._get_nested_collection_attributes(
+        nca = self._get_nested_collection_attributes(
             element_attributes=("element_identifier",), hda_attributes=("extension",), return_entities=(Dataset,)
         )
-        return [(row[:-2], row.extension, row.Dataset.get_file_name()) for row in q]
+        return [(row[:-2], row.extension, row.Dataset.get_file_name()) for row in nca]
@@ -6277,13 +6295,13 @@ def element_identifiers_extensions_paths_and_metadata_files(
     ) -> List[List[Any]]:
         results = []
         if object_session(self):
-            q = self._get_nested_collection_attributes(
+            nca = self._get_nested_collection_attributes(
                 element_attributes=("element_identifier",),
                 hda_attributes=("extension",),
                 return_entities=(HistoryDatasetAssociation, Dataset),
             )
             # element_identifiers, extension, path
-            for row in q:
+            for row in nca:
                 result = [row[:-3], row.extension, row.Dataset.get_file_name()]
                 hda = row.HistoryDatasetAssociation
                 result.append(hda.get_metadata_file_paths_and_extensions())
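Not part of the diff: a hedged sketch of the `populated_optimized` rewrite above, in which the legacy `q.session.query(~exists(q.subquery())).scalar()` idiom becomes `select(~exists(inner_stmt))` evaluated via `Session.scalar()`. The `Item` model and its `populated_state` column are made-up stand-ins for the collection tables.

```python
# Hedged sketch (SQLAlchemy 1.4/2.0): NOT EXISTS over an inner SELECT, mirroring
# the populated_optimized hunk. Model and column names are made up.
from sqlalchemy import Column, Integer, String, create_engine, exists, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Item(Base):
    __tablename__ = "item"
    id = Column(Integer, primary_key=True)
    populated_state = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Item(populated_state="ok"), Item(populated_state="ok")])
    session.commit()

    # Inner statement: rows that are NOT in the desired state.
    inner = select(Item.id).where(Item.populated_state != "ok")

    # 2.0 style: SELECT NOT EXISTS (<inner>), fetched as a single scalar.
    all_ok = session.scalar(select(~exists(inner)))
    assert all_ok
```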
@@ -6430,7 +6448,7 @@ def copy(
 
     def replace_failed_elements(self, replacements):
         hda_id_to_element = dict(
-            self._get_nested_collection_attributes(return_entities=[DatasetCollectionElement], hda_attributes=["id"])
+            self._get_nested_collection_attributes(return_entities=[DatasetCollectionElement], hda_attributes=["id"]).all()
         )
         for failed, replacement in replacements.items():
             element = hda_id_to_element.get(failed.id)
diff --git a/test/unit/data/test_galaxy_mapping.py b/test/unit/data/test_galaxy_mapping.py
index c7a2d4eb8366..67a74536841d 100644
--- a/test/unit/data/test_galaxy_mapping.py
+++ b/test/unit/data/test_galaxy_mapping.py
@@ -392,18 +392,23 @@ def test_nested_collection_attributes(self):
         )
         self.model.session.add_all([d1, d2, c1, dce1, dce2, c2, dce3, c3, c4, dce4])
         self.model.session.flush()
-        q = c2._get_nested_collection_attributes(
+
+        rows = c2._get_nested_collection_attributes(
             element_attributes=("element_identifier",), hda_attributes=("extension",), dataset_attributes=("state",)
-        )
-        assert [(r._fields) for r in q] == [
+        ).all()
+
+        assert [r._fields for r in rows] == [
             ("element_identifier_0", "element_identifier_1", "extension", "state"),
             ("element_identifier_0", "element_identifier_1", "extension", "state"),
         ]
-        assert q.all() == [("inner_list", "forward", "bam", "new"), ("inner_list", "reverse", "txt", "new")]
-        q = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation,))
-        assert q.all() == [d1, d2]
-        q = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation, model.Dataset))
-        assert q.all() == [(d1, d1.dataset), (d2, d2.dataset)]
+        assert rows == [("inner_list", "forward", "bam", "new"), ("inner_list", "reverse", "txt", "new")]
+
+        rows = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation,)).all()
+        assert rows == [d1, d2]
+
+        rows = c2._get_nested_collection_attributes(return_entities=(model.HistoryDatasetAssociation, model.Dataset)).all()
+        assert rows == [(d1, d1.dataset), (d2, d2.dataset)]
+
         # Assert properties that use _get_nested_collection_attributes return correct content
         assert c2.dataset_instances == [d1, d2]
         assert c2.dataset_elements == [dce1, dce2]
@@ -422,8 +427,10 @@ def test_nested_collection_attributes(self):
         assert c3.dataset_instances == []
         assert c3.dataset_elements == []
         assert c3.dataset_states_and_extensions_summary == (set(), set())
-        q = c4._get_nested_collection_attributes(element_attributes=("element_identifier",))
-        assert q.all() == [("outer_list", "inner_list", "forward"), ("outer_list", "inner_list", "reverse")]
+
+        rows = c4._get_nested_collection_attributes(element_attributes=("element_identifier",)).all()
+        assert rows == [("outer_list", "inner_list", "forward"), ("outer_list", "inner_list", "reverse")]
+
         assert c4.dataset_elements == [dce1, dce2]
         assert c4.element_identifiers_extensions_and_paths == [
             (("outer_list", "inner_list", "forward"), "bam", "mock_dataset_14.dat"),
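Not part of the diff: a small sketch of the `Result` behavior the rewritten assertions rely on. `Session.execute(select(...))` yields `Row` objects that compare equal to plain tuples and expose the selected column names via `._fields`; the `Sample` model below is a made-up stand-in for the collection tables.

```python
# Hedged sketch (SQLAlchemy 1.4/2.0): Row objects compare equal to tuples and
# carry ._fields, which the updated test assertions depend on. Names are made up.
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Sample(Base):
    __tablename__ = "sample"
    id = Column(Integer, primary_key=True)
    extension = Column(String)
    state = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Sample(extension="bam", state="new"))
    session.commit()

    rows = session.execute(select(Sample.extension, Sample.state)).all()

    assert rows == [("bam", "new")]  # Row compares equal to a plain tuple
    assert rows[0]._fields == ("extension", "state")  # selected column names
```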