Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[23.2] Fix social_core methods #17530

Merged
merged 2 commits into from
Feb 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 95 additions & 19 deletions lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
@@ -9526,6 +9526,11 @@ def save(self):

@classmethod
def store(cls, server_url, association):
"""
Create an Association instance
(Required by social_core.storage.AssociationMixin interface)
"""

def get_or_create():
stmt = select(PSAAssociation).filter_by(server_url=server_url, handle=association.handle).limit(1)
assoc = cls.sa_session.scalars(stmt).first()
@@ -9542,11 +9547,19 @@ def get_or_create():

@classmethod
def get(cls, *args, **kwargs):
"""
Get an Association instance
(Required by social_core.storage.AssociationMixin interface)
"""
stmt = select(PSAAssociation).filter_by(*args, **kwargs)
return cls.sa_session.scalars(stmt).all()

@classmethod
def remove(cls, ids_to_delete):
"""
Remove an Association instance
(Required by social_core.storage.AssociationMixin interface)
"""
stmt = (
delete(PSAAssociation)
.where(PSAAssociation.id.in_(ids_to_delete))
@@ -9577,6 +9590,9 @@ def save(self):

@classmethod
def get_code(cls, code):
"""
(Required by social_core.storage.CodeMixin interface)
"""
stmt = select(PSACode).where(PSACode.code == code).limit(1)
return cls.sa_session.scalars(stmt).first()

@@ -9604,6 +9620,10 @@ def save(self):

@classmethod
def use(cls, server_url, timestamp, salt):
"""
Create a Nonce instance
(Required by social_core.storage.NonceMixin interface)
"""
try:
stmt = select(PSANonce).where(server_url=server_url, timestamp=timestamp, salt=salt).limit(1)
return cls.sa_session.scalars(stmt).first()
@@ -9640,11 +9660,17 @@ def save(self):

@classmethod
def load(cls, token):
"""
(Required by social_core.storage.PartialMixin interface)
"""
stmt = select(PSAPartial).where(PSAPartial.token == token).limit(1)
return cls.sa_session.scalars(stmt).first()

@classmethod
def destroy(cls, token):
"""
(Required by social_core.storage.PartialMixin interface)
"""
partial = cls.load(token)
if partial:
session = cls.sa_session
@@ -9695,30 +9721,63 @@ def save(self):
with transaction(self.sa_session):
self.sa_session.commit()

@classmethod
def username_max_length(cls):
# Note: This is the maximum field length set for the username column of the galaxy_user table.
# A better alternative is to retrieve this number from the table, instead of this const value.
return 255

@classmethod
def changed(cls, user):
"""
The given user instance is ready to be saved
(Required by social_core.storage.UserMixin interface)
"""
cls.sa_session.add(user)
with transaction(cls.sa_session):
cls.sa_session.commit()

@classmethod
def get_username(cls, user):
"""
Return the username for given user
(Required by social_core.storage.UserMixin interface)
"""
return getattr(user, "username", None)

@classmethod
def user_model(cls):
"""
Return the user model
(Required by social_core.storage.UserMixin interface)
"""
return User

@classmethod
def username_max_length(cls):
"""
Return the max length for username
(Required by social_core.storage.UserMixin interface)
"""
# Note: This is the maximum field length set for the username column of the galaxy_user table.
# A better alternative is to retrieve this number from the table, instead of this const value.
return 255

@classmethod
def user_exists(cls, *args, **kwargs):
"""
Return True/False if a User instance exists with the given arguments.
Arguments are directly passed to filter() manager method.
(Required by social_core.storage.UserMixin interface)
"""
stmt_user = select(User).filter_by(*args, **kwargs)
stmt_count = select(func.count()).select_from(stmt_user)
return cls.sa_session.scalar(stmt_count) > 0

@classmethod
def create_user(cls, *args, **kwargs):
"""
This is used by PSA authnz, do not use directly.
Prefer using the user manager.
(Required by social_core.storage.UserMixin interface)
"""
instance = User(*args, **kwargs)
if cls.email_exists(instance.email):
model = cls.user_model()
instance = model(*args, **kwargs)
if cls.get_users_by_email(instance.email):
raise Exception(f"User with this email '{instance.email}' already exists.")
instance.set_random_password()
cls.sa_session.add(instance)
@@ -9728,33 +9787,50 @@ def create_user(cls, *args, **kwargs):

@classmethod
def get_user(cls, pk):
return UserAuthnzToken.sa_session.get(User, pk)
"""
Return user instance for given id
(Required by social_core.storage.UserMixin interface)
"""
return cls.sa_session.get(User, pk)

@classmethod
def email_exists(cls, email):
stmt = select(User).where(func.lower(User.email) == email.lower()).limit(1)
return bool(cls.sa_session.scalars(stmt).first())
def get_users_by_email(cls, email):
"""
Return users instances for given email address
(Required by social_core.storage.UserMixin interface)
"""
stmt = select(User).where(func.lower(User.email) == email.lower())
return cls.sa_session.scalars(stmt).all()

@classmethod
def get_social_auth(cls, provider, uid):
"""
Return UserSocialAuth for given provider and uid
(Required by social_core.storage.UserMixin interface)
"""
uid = str(uid)
try:
stmt = select(UserAuthnzToken).filter_by(provider=provider, uid=uid).limit(1)
return cls.sa_session.scalars(stmt).first()
except IndexError:
return None
stmt = select(cls).filter_by(provider=provider, uid=uid).limit(1)
return cls.sa_session.scalars(stmt).first()

@classmethod
def get_social_auth_for_user(cls, user, provider=None, id=None):
stmt = select(UserAuthnzToken).filter_by(user_id=user.id)
"""
Return all the UserSocialAuth instances for given user
(Required by social_core.storage.UserMixin interface)
"""
stmt = select(cls).filter_by(user_id=user.id)
if provider:
stmt = stmt.filter_by(provider=provider)
if id:
stmt = stmt.filter_by(id=id)
return cls.sa_session.scalars(stmt)
return cls.sa_session.scalars(stmt).all()

@classmethod
def create_social_auth(cls, user, uid, provider):
"""
Create a UserSocialAuth instance for given user
(Required by social_core.storage.UserMixin interface)
"""
uid = str(uid)
instance = cls(user=user, uid=uid, provider=provider)
cls.sa_session.add(instance)