Skip to content

Commit

Permalink
Fix SA2.0 usage in managers.users
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Oct 13, 2023
1 parent aa29015 commit 17537e5
Showing 1 changed file with 21 additions and 28 deletions.
49 changes: 21 additions & 28 deletions lib/galaxy/managers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from galaxy.model import (
User,
UserAddress,
UserQuotaUsage,
)
from galaxy.model.base import transaction
Expand Down Expand Up @@ -233,13 +234,8 @@ def purge(self, user, flush=True):
user.username = uname_hash
# Redact user addresses as well
if self.app.config.redact_user_address_during_deletion:
user_addresses = (
self.session()
.query(self.app.model.UserAddress)
.filter(self.app.model.UserAddress.user_id == user.id)
.all()
)
for addr in user_addresses:
stmt = select(UserAddress).where(UserAddress.user_id == user.id)
for addr in self.session().scalars(stmt):
addr.desc = new_secure_hash_v2(addr.desc + pseudorandom_value)
addr.name = new_secure_hash_v2(addr.name + pseudorandom_value)
addr.institution = new_secure_hash_v2(addr.institution + pseudorandom_value)
Expand All @@ -264,7 +260,7 @@ def _error_on_duplicate_email(self, email: str) -> None:
raise exceptions.Conflict("Email must be unique", email=email)

def by_id(self, user_id: int) -> model.User:
return self.app.model.session.query(self.model_class).get(user_id)
return self.app.model.session.get(self.model_class, user_id)

# ---- filters
def by_email(self, email: str, filters=None, **kwargs) -> Optional[model.User]:
Expand All @@ -286,7 +282,8 @@ def by_api_key(self, api_key: str, sa_session=None):
return schema.BootstrapAdminUser()
sa_session = sa_session or self.app.model.session
try:
provided_key = sa_session.query(self.app.model.APIKeys).filter_by(key=api_key, deleted=False).one()
stmt = select(self.app.model.APIKeys).filter_by(key=api_key, deleted=False)
provided_key = sa_session.execute(stmt).scalar_one()
except NoResultFound:
raise exceptions.AuthenticationFailed("Provided API key is not valid.")
if provided_key.user.deleted:
Expand Down Expand Up @@ -363,12 +360,8 @@ def get_user_by_identity(self, identity):
user = get_user_by_email(self.session(), identity, self.model_class)
if not user:
# Try a case-insensitive match on the email
user = (
self.session()
.query(self.model_class)
.filter(func.lower(self.model_class.table.c.email) == identity.lower())
.first()
)
stmt = select(self.model_class).where(func.lower(self.model_class.email) == identity.lower()).limit(1)
user = self.session().scalars(stmt).first()
else:
user = get_user_by_username(self.session(), identity, self.model_class)
return user
Expand Down Expand Up @@ -445,7 +438,7 @@ def change_password(self, trans, password=None, confirm=None, token=None, id=Non
if not token and not id:
return None, "Please provide a token or a user and password."
if token:
token_result = trans.sa_session.query(self.app.model.PasswordResetToken).get(token)
token_result = trans.sa_session.get(self.app.model.PasswordResetToken, token)
if not token_result or not token_result.expiration_time > datetime.utcnow():
return None, "Invalid or expired password reset token, please request a new one."
user = token_result.user
Expand Down Expand Up @@ -483,13 +476,14 @@ def __set_password(self, trans, user, password, confirm):
user.set_password_cleartext(password)
# Invalidate all other sessions
if trans.galaxy_session:
for other_galaxy_session in trans.sa_session.query(self.app.model.GalaxySession).filter(
stmt = select(self.app.model.GalaxySession).where(
and_(
self.app.model.GalaxySession.table.c.user_id == user.id,
self.app.model.GalaxySession.table.c.is_valid == true(),
self.app.model.GalaxySession.table.c.id != trans.galaxy_session.id,
self.app.model.GalaxySession.user_id == user.id,
self.app.model.GalaxySession.is_valid == true(),
self.app.model.GalaxySession.id != trans.galaxy_session.id,
)
):
)
for other_galaxy_session in trans.sa_session.scalars(stmt):
other_galaxy_session.is_valid = False
trans.sa_session.add(other_galaxy_session)
trans.sa_session.add(user)
Expand Down Expand Up @@ -581,11 +575,8 @@ def send_reset_email(self, trans, payload, **kwd):
def get_reset_token(self, trans, email):
reset_user = get_user_by_email(trans.sa_session, email, self.app.model.User)
if not reset_user and email != email.lower():
reset_user = (
trans.sa_session.query(self.app.model.User)
.filter(func.lower(self.app.model.User.table.c.email) == email.lower())
.first()
)
stmt = select(self.app.model.User).where(func.lower(self.app.model.User.email) == email.lower()).limit(1)
reset_user = trans.sa_session.scalars(stmt).first()
if reset_user:
prt = self.app.model.PasswordResetToken(reset_user)
trans.sa_session.add(prt)
Expand Down Expand Up @@ -644,9 +635,11 @@ def get_or_create_remote_user(self, remote_user_email):
for char in [x for x in username if x not in f"{string.ascii_lowercase + string.digits}-."]:
username = username.replace(char, "-")
# Find a unique username - user can change it later
if self.session().query(self.app.model.User).filter_by(username=username).first():
stmt = select(self.app.model.User).filter_by(username=username).limit(1)
if self.session().scalars(stmt).first():
i = 1
while self.session().query(self.app.model.User).filter_by(username=f"{username}-{str(i)}").first():
stmt = select(self.app.model.User).filter_by(username=f"{username}-{str(i)}").limit(1)
while self.session().scalars(stmt).first():
i += 1
username += f"-{str(i)}"
user.username = username
Expand Down

0 comments on commit 17537e5

Please sign in to comment.