Skip to content

Commit

Permalink
Fix error in select, add test
Browse files Browse the repository at this point in the history
We must use a union because when we retrieve roles with a query, we
check against:
1) role name
2) email of associated user for private roles

We factor out this select into a helper method, which we then test
extensively.
  • Loading branch information
jdavcs committed Oct 30, 2024
1 parent 1860e88 commit 16f1d3f
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 32 deletions.
76 changes: 44 additions & 32 deletions lib/galaxy/model/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,38 +159,7 @@ def get_valid_roles(self, trans, item, query=None, page=None, page_limit=None, i
is_public_item = False
# Admins can always choose from all non-deleted roles
if trans.user_is_admin or trans.app.config.expose_user_email:

stmt = select(Role)

if search_query:
stmt = stmt.join(Role.users).join(User) # We need to query against user email
stmt = stmt.where(
or_(Role.name.like(search_query, escape="/"), User.email.like(search_query, escape="/"))
)

stmt = stmt.where(Role.deleted == false())

if not trans.user_is_admin:
# User is not an admin but the configuration exposes all private roles to all users.
stmt = stmt.where(Role.type == Role.types.PRIVATE)

count_stmt = select(func.count()).select_from(stmt)
total_count = trans.sa_session.scalar(count_stmt)

if limit is not None:
# Takes the least number of results from beginning that includes the requested page
stmt = stmt.order_by(Role.name).limit(limit)
page_start = (page * page_limit) - page_limit
page_end = page_start + page_limit
if total_count < page_start + 1:
# Return empty list if there are less results than the requested position
roles = []
else:
roles = trans.sa_session.scalars(stmt).all()
roles = roles[page_start:page_end]
else:
stmt = stmt.order_by(Role.name)
roles = trans.sa_session.scalars(stmt).all()
roles = _get_valid_roles_case1(trans.sa_session, search_query, trans.user_is_admin, limit, page, page_limit)
# Non-admin and public item
elif is_public_item:
# Add the current user's private role
Expand Down Expand Up @@ -1814,3 +1783,46 @@ def is_foreign_key_violation(error):
# If this is a PostgreSQL foreign key error, then error.orig is an instance of psycopg2.errors.ForeignKeyViolation
# and should have an attribute `pgcode` = 23503.
return int(getattr(error.orig, "pgcode", -1)) == 23503


def _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit):
"""Case: trans.user_is_admin or trans.app.config.expose_user_email"""
stmt = select(Role).where(Role.deleted == false())

if not is_admin:
# User is not an admin but the configuration exposes all private roles to all users,
# so only private roles are returned.
stmt = stmt.where(Role.type == Role.types.PRIVATE)

if search_query:
stmt = stmt.where(Role.name.like(search_query, escape="/"))

# Also check against user emails for associated users of private roles ONLY
stmt2 = (
select(Role)
.join(Role.users)
.join(User)
.where(and_(Role.type == Role.types.PRIVATE, User.email.like(search_query, escape="/")))
)
stmt = stmt.union(stmt2)

count_stmt = select(func.count()).select_from(stmt)
total_count = session.scalar(count_stmt)

stmt = stmt.order_by(Role.name)

if limit is not None:
# Takes the least number of results from beginning that includes the requested page
stmt = stmt.limit(limit)
page_start = (page * page_limit) - page_limit
page_end = page_start + page_limit
if total_count < page_start + 1:
# Return empty list if there are less results than the requested position
return []

stmt = select(Role).from_statement(stmt)
roles = session.scalars(stmt).all()
if limit is not None:
roles = roles[page_start:page_end]

return roles
97 changes: 97 additions & 0 deletions test/unit/data/model/db/test_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
get_private_user_role,
get_roles_by_ids,
)
from galaxy.model.security import _get_valid_roles_case1
from . import have_same_elements


Expand Down Expand Up @@ -42,3 +43,99 @@ def test_get_roles_by_ids(session, make_role):
roles2 = get_roles_by_ids(session, ids)
expected = [r1, r2, r3]
have_same_elements(roles2, expected)


def test_get_falid_roles_case1(session, make_user_and_role, make_user, make_role, make_user_role_association):
# Make 3 users with private roles
(
u1,
rp1,
) = make_user_and_role(email="[email protected]")
(
u2,
rp2,
) = make_user_and_role(email="[email protected]")
(
u3,
rp3,
) = make_user_and_role(email="[email protected]")

# Make 2 sharing roles
rs1 = make_role(type="sharing", name="sharing role for u1")
make_user_role_association(user=u1, role=rs1)
rs2 = make_role(type="sharing", name="sharing role for u2")
make_user_role_association(user=u2, role=rs2)

# Make 4 admin roles
ra1 = make_role(type="admin", name="admin role1")
make_user_role_association(user=u1, role=ra1)
make_user_role_association(user=u2, role=ra1)
ra2 = make_role(type="admin", name="admin role2")
make_user_role_association(user=u1, role=ra2)
make_user_role_association(user=u2, role=ra2)

limit, page, page_limit = 1000, 1, 1000

is_admin = True

search_query = None
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 7 # all roles returned

search_query = "foo%"
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 2
assert rp1 in roles
assert rp2 in roles

search_query = "foo1%"
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 1
assert roles[0] == rp1

search_query = "sharing%"
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 2
assert rs1 in roles
assert rs2 in roles

search_query = "sharing role for u1%"
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 1
assert roles[0] == rs1

search_query = "admin role%"
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 2
assert ra1 in roles
assert ra2 in roles

search_query = "admin role1%"
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 1
assert roles[0] == ra1

is_admin = False # non admins should see only private roles

search_query = None
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 3

search_query = "foo%"
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 2
assert rp1 in roles
assert rp2 in roles

search_query = "foo1%"
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 1
assert roles[0] == rp1

search_query = "sharing%"
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 0

search_query = "admin role%"
roles = _get_valid_roles_case1(session, search_query, is_admin, limit, page, page_limit)
assert len(roles) == 0

0 comments on commit 16f1d3f

Please sign in to comment.