From 144289d4b7a440b9d49d8e5a2fc28ed82c25f50e Mon Sep 17 00:00:00 2001 From: John Davis Date: Mon, 25 Sep 2023 17:53:39 -0400 Subject: [PATCH] Fix SA2.0 ORM usage in model.security --- lib/galaxy/model/security.py | 246 ++++++++++++++++------------------- 1 file changed, 111 insertions(+), 135 deletions(-) diff --git a/lib/galaxy/model/security.py b/lib/galaxy/model/security.py index 84c3eb5fb976..1f318dd556af 100644 --- a/lib/galaxy/model/security.py +++ b/lib/galaxy/model/security.py @@ -9,12 +9,29 @@ from sqlalchemy import ( and_, false, + func, not_, or_, + select, ) from sqlalchemy.orm import joinedload import galaxy.model +from galaxy.model import ( + Dataset, + DatasetPermissions, + Group, + GroupRoleAssociation, + HistoryDatasetAssociationDisplayAtAuthorization, + Library, + LibraryDataset, + LibraryDatasetPermissions, + LibraryPermissions, + Role, + User, + UserGroupAssociation, + UserRoleAssociation, +) from galaxy.model.base import transaction from galaxy.security import ( Action, @@ -65,17 +82,12 @@ def _get_npns_roles(self, trans): """ non-private, non-sharing roles """ - return ( - trans.sa_session.query(trans.app.model.Role) - .filter( - and_( - self.model.Role.deleted == false(), - self.model.Role.type != self.model.Role.types.PRIVATE, - self.model.Role.type != self.model.Role.types.SHARING, - ) - ) - .order_by(self.model.Role.name) + stmt = ( + select(Role) + .where(and_(Role.deleted == false(), Role.type != Role.types.PRIVATE, Role.type != Role.types.SHARING)) + .order_by(Role.name) ) + return trans.sa_session.scalars(stmt) def get_all_roles(self, trans, cntrller): admin_controller = cntrller in ["library_admin"] @@ -84,11 +96,8 @@ def get_all_roles(self, trans, cntrller): return self._get_npns_roles(trans) if admin_controller: # The library is public and the user is an admin, so all roles are legitimate - for role in ( - trans.sa_session.query(trans.app.model.Role) - .filter(self.model.Role.deleted == false()) - .order_by(self.model.Role.name) - ): + stmt = select(Role).where(Role.deleted == false()).order_by(Role.name) + for role in trans.sa_session.scalars(stmt): roles.add(role) else: # Add the current user's private role @@ -146,27 +155,30 @@ def get_valid_roles(self, trans, item, query=None, page=None, page_limit=None, i # Admins can always choose from all non-deleted roles if trans.user_is_admin or trans.app.config.expose_user_email: if trans.user_is_admin: - db_query = trans.sa_session.query(trans.app.model.Role).filter(self.model.Role.deleted == false()) + stmt = select(Role).where(Role.deleted == false()) else: # User is not an admin but the configuration exposes all private roles to all users. - db_query = trans.sa_session.query(trans.app.model.Role).filter( - and_(self.model.Role.deleted == false(), self.model.Role.type == self.model.Role.types.PRIVATE) - ) + stmt = select(Role).where(and_(Role.deleted == false(), Role.type == Role.types.PRIVATE)) if search_query: - db_query = db_query.filter(self.model.Role.name.like(search_query, escape="/")) - total_count = db_query.count() + stmt = stmt.where(Role.name.like(search_query, escape="/")) + + count_stmt = select(func.count()).select_from(stmt) + total_count = trans.sa_session.scalar(count_stmt) + if limit is not None: # Takes the least number of results from beginning that includes the requested page - roles = db_query.order_by(self.model.Role.name).limit(limit).all() + stmt = stmt.order_by(Role.name).limit(limit) page_start = (page * page_limit) - page_limit page_end = page_start + page_limit if total_count < page_start + 1: # Return empty list if there are less results than the requested position roles = [] else: + roles = trans.sa_session.scalars(stmt).all() roles = roles[page_start:page_end] else: - roles = db_query.order_by(self.model.Role.name) + stmt = stmt.order_by(Role.name) + roles = trans.sa_session.scalars(stmt).all() # Non-admin and public item elif is_public_item: # Add the current user's private role @@ -324,17 +336,13 @@ def get_actions_for_items(self, trans, action, permission_items): # SM: NB: LibraryDatasets became Datasets for some odd reason. if isinstance(permission_items[0], trans.model.LibraryDataset): ids = [item.library_dataset_id for item in permission_items] - permissions = ( - trans.sa_session.query(trans.model.LibraryDatasetPermissions) - .filter( - and_( - trans.model.LibraryDatasetPermissions.library_dataset_id.in_(ids), - trans.model.LibraryDatasetPermissions.action == action.action, - ) + stmt = select(LibraryDatasetPermissions).where( + and_( + LibraryDatasetPermissions.library_dataset_id.in_(ids), + LibraryDatasetPermissions.action == action.action, ) - .all() ) - + permissions = trans.sa_session.scalars(stmt) # Massage the return data. We will return a list of permissions # for each library dataset. So we initialize the return list to # have an empty list for each dataset. Then each permission is @@ -347,17 +355,11 @@ def get_actions_for_items(self, trans, action, permission_items): ret_permissions[permission.library_dataset_id].append(permission) elif isinstance(permission_items[0], trans.model.Dataset): ids = [item.id for item in permission_items] - permissions = ( - trans.sa_session.query(trans.model.DatasetPermissions) - .filter( - and_( - trans.model.DatasetPermissions.dataset_id.in_(ids), - trans.model.DatasetPermissions.action == action.action, - ) - ) - .all() - ) + stmt = select(DatasetPermissions).where( + and_(DatasetPermissions.dataset_id.in_(ids), DatasetPermissions.action == action.action) + ) + permissions = trans.sa_session.scalars(stmt) # Massage the return data. We will return a list of permissions # for each library dataset. So we initialize the return list to # have an empty list for each dataset. Then each permission is @@ -540,38 +542,34 @@ def get_accessible_libraries(self, trans, user): accessible_libraries = [] current_user_role_ids = [role.id for role in user.all_roles()] library_access_action = self.permitted_actions.LIBRARY_ACCESS.action - restricted_library_ids = [ - lp.library_id - for lp in trans.sa_session.query(trans.model.LibraryPermissions) - .filter(trans.model.LibraryPermissions.table.c.action == library_access_action) - .distinct() - ] - accessible_restricted_library_ids = [ - lp.library_id - for lp in trans.sa_session.query(trans.model.LibraryPermissions).filter( - and_( - trans.model.LibraryPermissions.table.c.action == library_access_action, - trans.model.LibraryPermissions.table.c.role_id.in_(current_user_role_ids), - ) + + stmt = select(LibraryPermissions).where(LibraryPermissions.action == library_access_action).distinct() + restricted_library_ids = [lp.library_id for lp in trans.sa_session.scalars(stmt)] + + stmt = select(LibraryPermissions).where( + and_( + LibraryPermissions.action == library_access_action, + LibraryPermissions.role_id.in_(current_user_role_ids), ) - ] + ) + accessible_restricted_library_ids = [lp.library_id for lp in trans.sa_session.scalars(stmt)] + # Filter to get libraries accessible by the current user. Get both # public libraries and restricted libraries accessible by the current user. - for library in ( - trans.sa_session.query(trans.model.Library) - .filter( + stmt = ( + select(Library) + .where( and_( - trans.model.Library.table.c.deleted == false(), - ( - or_( - not_(trans.model.Library.table.c.id.in_(restricted_library_ids)), - trans.model.Library.table.c.id.in_(accessible_restricted_library_ids), - ) + Library.deleted == false(), + or_( + not_(Library.id.in_(restricted_library_ids)), + Library.id.in_(accessible_restricted_library_ids), ), ) ) - .order_by(trans.app.model.Library.name) - ): + .order_by(Library.name) + ) + for library in trans.sa_session.scalars(stmt): accessible_libraries.append(library) return accessible_libraries @@ -589,12 +587,10 @@ def has_accessible_folders(self, trans, folder, user, roles, search_downward=Tru return False def has_accessible_library_datasets(self, trans, folder, user, roles, search_downward=True): - for library_dataset in trans.sa_session.query(trans.model.LibraryDataset).filter( - and_( - trans.model.LibraryDataset.table.c.deleted == false(), - trans.app.model.LibraryDataset.table.c.folder_id == folder.id, - ) - ): + stmt = select(LibraryDataset).where( + and_(LibraryDataset.deleted == false(), LibraryDataset.folder_id == folder.id) + ) + for library_dataset in trans.sa_session.scalars(stmt): if self.can_access_library_item(roles, library_dataset, user): return True if search_downward: @@ -749,17 +745,14 @@ def create_private_user_role(self, user): return self.get_private_user_role(user) def get_private_user_role(self, user, auto_create=False): - role = ( - self.sa_session.query(self.model.Role) - .filter( - and_( - self.model.UserRoleAssociation.table.c.user_id == user.id, - self.model.Role.id == self.model.UserRoleAssociation.table.c.role_id, - self.model.Role.type == self.model.Role.types.PRIVATE, - ) + stmt = select(Role).where( + and_( + UserRoleAssociation.user_id == user.id, + Role.id == UserRoleAssociation.role_id, + Role.type == Role.types.PRIVATE, ) - .one_or_none() ) + role = self.sa_session.execute(stmt).scalar_one_or_none() if not role: if auto_create: return self.create_private_user_role(user) @@ -770,21 +763,18 @@ def get_private_user_role(self, user, auto_create=False): def get_role(self, name, type=None): type = type or self.model.Role.types.SYSTEM # will raise exception if not found - return ( - self.sa_session.query(self.model.Role) - .filter(and_(self.model.Role.name == name, self.model.Role.type == type)) - .one() - ) + stmt = select(Role).where(and_(Role.name == name, Role.type == type)) + return self.sa_session.execute(stmt).scalar_one() def create_role(self, name, description, in_users, in_groups, create_group_for_role=False, type=None): type = type or self.model.Role.types.SYSTEM role = self.model.Role(name=name, description=description, type=type) self.sa_session.add(role) # Create the UserRoleAssociations - for user in [self.sa_session.query(self.model.User).get(x) for x in in_users]: + for user in [self.sa_session.get(User, x) for x in in_users]: self.associate_user_role(user, role) # Create the GroupRoleAssociations - for group in [self.sa_session.query(self.model.Group).get(x) for x in in_groups]: + for group in [self.sa_session.get(Group, x) for x in in_groups]: self.associate_group_role(group, role) if create_group_for_role: # Create the group @@ -800,12 +790,10 @@ def create_role(self, name, description, in_users, in_groups, create_group_for_r return role, num_in_groups def get_sharing_roles(self, user): - return self.sa_session.query(self.model.Role).filter( - and_( - (self.model.Role.name).like(f"Sharing role for: %{user.email}%"), - self.model.Role.type == self.model.Role.types.SHARING, - ) + stmt = select(Role).where( + and_((Role.name).like(f"Sharing role for: %{user.email}%"), Role.type == Role.types.SHARING) ) + return self.sa_session.scalars(stmt) def user_set_default_permissions( self, @@ -1217,16 +1205,13 @@ def datasets_are_public(self, trans, datasets): datasets_public[dataset_id] = True # Now get all datasets which have DATASET_ACCESS actions: - access_data_perms = ( - trans.sa_session.query(trans.app.model.DatasetPermissions) - .filter( - and_( - trans.app.model.DatasetPermissions.dataset_id.in_(dataset_ids), - trans.app.model.DatasetPermissions.action == self.permitted_actions.DATASET_ACCESS.action, - ) + stmt = select(DatasetPermissions).where( + and_( + DatasetPermissions.dataset_id.in_(dataset_ids), + DatasetPermissions.action == self.permitted_actions.DATASET_ACCESS.action, ) - .all() ) + access_data_perms = trans.sa_session.scalars(stmt) # Every dataset returned has "access" privileges associated with it, # so it's not public. for permission in access_data_perms: @@ -1264,14 +1249,14 @@ def derive_roles_from_access(self, trans, item_id, cntrller, library=False, **kw error = False for k, v in get_permitted_actions(filter="DATASET").items(): # Change for removing the prefix '_in' from the roles select box - in_roles = [self.sa_session.query(self.model.Role).get(x) for x in listify(kwd[k])] + in_roles = [self.sa_session.get(Role, x) for x in listify(kwd[k])] if not in_roles: - in_roles = [self.sa_session.query(self.model.Role).get(x) for x in listify(kwd.get(f"{k}_in", []))] + in_roles = [self.sa_session.get(Role, x) for x in listify(kwd.get(f"{k}_in", []))] if v == self.permitted_actions.DATASET_ACCESS and in_roles: if library: - item = self.sa_session.query(self.model.Library).get(item_id) + item = self.sa_session.get(Library, item_id) else: - item = self.sa_session.query(self.model.Dataset).get(item_id) + item = self.sa_session.get(Dataset, item_id) if (library and not self.library_is_public(item)) or (not library and not self.dataset_is_public(item)): # Ensure that roles being associated with DATASET_ACCESS are a subset of the legitimate roles # derived from the roles associated with the access permission on item if it's not public. This @@ -1387,11 +1372,8 @@ def get_permitted_libraries(self, trans, user, actions): libraries = trans.app.security_agent.get_permitted_libraries( trans, user, [ trans.app.security_agent.permitted_actions.LIBRARY_ADD ] ) """ - all_libraries = ( - trans.sa_session.query(trans.app.model.Library) - .filter(trans.app.model.Library.table.c.deleted == false()) - .order_by(trans.app.model.Library.name) - ) + stmt = select(Library).where(Library.deleted == false()).order_by(Library.name) + all_libraries = trans.sa_session.scalars(stmt) roles = user.all_roles() actions_to_check = actions # The libraries dictionary looks like: { library : '1,2' }, library : '3' } @@ -1520,31 +1502,23 @@ def get_component_associations(self, **kwd): assert len(kwd) == 2, "You must specify exactly 2 Galaxy security components to check for associations." if "dataset" in kwd: if "action" in kwd: - return ( - self.sa_session.query(self.model.DatasetPermissions) + stmt = ( + select(DatasetPermissions) .filter_by(action=kwd["action"].action, dataset_id=kwd["dataset"].id) - .first() + .limit(1) ) + return self.sa_session.scalars(stmt).first() elif "user" in kwd: if "group" in kwd: - return ( - self.sa_session.query(self.model.UserGroupAssociation) - .filter_by(group_id=kwd["group"].id, user_id=kwd["user"].id) - .first() - ) + stmt = select(UserGroupAssociation).filter_by(group_id=kwd["group"].id, user_id=kwd["user"].id).limit(1) + return self.sa_session.scalars(stmt).first() elif "role" in kwd: - return ( - self.sa_session.query(self.model.UserRoleAssociation) - .filter_by(role_id=kwd["role"].id, user_id=kwd["user"].id) - .first() - ) + stmt = select(UserRoleAssociation).filter_by(role_id=kwd["role"].id, user_id=kwd["user"].id).limit(1) + return self.sa_session.scalars(stmt).first() elif "group" in kwd: if "role" in kwd: - return ( - self.sa_session.query(self.model.GroupRoleAssociation) - .filter_by(role_id=kwd["role"].id, group_id=kwd["group"].id) - .first() - ) + stmt = select(GroupRoleAssociation).filter_by(role_id=kwd["role"].id, group_id=kwd["group"].id).limit(1) + return self.sa_session.scalars(stmt).first() raise Exception(f"No valid method of associating provided components: {kwd}") def check_folder_contents(self, user, roles, folder, hidden_folder_ids=""): @@ -1635,11 +1609,12 @@ def allow_action(self, addr, action, **kwd): ]: log.debug("Allowing access to public dataset with hda: %i." % hda.id) return True # dataset has no roles associated with the access permission, thus is already public - hdadaa = ( - self.sa_session.query(self.model.HistoryDatasetAssociationDisplayAtAuthorization) + stmt = ( + select(HistoryDatasetAssociationDisplayAtAuthorization) .filter_by(history_dataset_association_id=hda.id) - .first() + .limit(1) ) + hdadaa = self.sa_session.scalars(stmt).first() if not hdadaa: log.debug( "Denying access to private dataset with hda: %i. No hdadaa record for this dataset." % hda.id @@ -1677,11 +1652,12 @@ def allow_action(self, addr, action, **kwd): raise Exception("The dataset access permission is the only valid permission in the host security agent.") def set_dataset_permissions(self, hda, user, site): - hdadaa = ( - self.sa_session.query(self.model.HistoryDatasetAssociationDisplayAtAuthorization) + stmt = ( + select(HistoryDatasetAssociationDisplayAtAuthorization) .filter_by(history_dataset_association_id=hda.id) - .first() + .limit(1) ) + hdadaa = self.sa_session.scalars(stmt).first() if hdadaa: hdadaa.update_time = datetime.utcnow() else: