Skip to content

Commit

Permalink
[TMP] Combined commits from galaxyproject#17180
Browse files Browse the repository at this point in the history
Upgrade SQLAlchemy to 2.0

This conflicts with dependency requirements for sqlalchemy-graphene
(used only in toolshed, new WIP client)

Remove RemovedIn20Warning from config

This does not exist in SQLAlchemy 2.0

Update import path for DeclarativeMeta

Move declaration of injected attrs into constructor
Remove unused import

For context: https://github.com/galaxyproject/galaxy/pull/14717/files#r1486979280

Also, remove model attr type hints that conflict with SA2.0

Apply Mapped/mapped_column to model definitions

Included models: galaxy, tool shed, tool shed install
Column types:
DateTime
Integer
Boolan
Unicode
String (Text/TEXT/TrimmedString/VARCHAR)
UUID
Numeric

NOTE on typing of nullability: db schema != python app

- Mapped[datetime] specifies correct type for the python app;
- nullable=True specifies correct mapping to the db schema (that's what
  the CREATE TABLE sql statement will reflect).

mapped_column.nullable takes precedence over typing annotation of
Mapped. So, if we have:

foo: Mapped[str] = mapped_column(String, nullable=True)

- that means that the foo db field will allow NULL, but the python app
  will not allow foo = None. And vice-versa:

bar: Mapped[Optional[str]] = mapped_column(String, nullable=False)

- the bar db field is NOT NULL, but bar = None is OK.

This might need to be applied to other column definitions, but for now
this addresses specific mypy errors.

Ref: https://docs.sqlalchemy.org/en/20/orm/declarative_tables.html#mapped-column-derives-the-datatype-and-nullability-from-the-mapped-annotation

Add typing to JSON columns, fix related mypy errors

Columns:
MutableJSONType
JSONType
DoubleEncodedJsonType

TODO: I think we need a type alias for json-typed columns: bytes understand
iteration, but not access by key.

Use correct type hints to define common model attrs

Start applying Mapped to relationship definitions in the model

Remove column declaration from HasTags parent class

Fix SA2.0 error: wrap sql in text()

Fix SA2.0 error: pass bind to create_all

Fix SA2.0 error: use Row._mapping for keyed attribute access

Ref: https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#result-rows-act-like-named-tuples

Fix SA2.0 error: show password in url

SA 1.4: str(url) renders connection string with password
SA 2.0: str(url) renders connection string WITHOUT password
Solution: Use render_as_string(hide_password=False)

Fix SA2.0 error: use attribute_keyed_dict

Replaces attribute_mapped_collection (SA20)

Fix SA2.0 error: make select stmt a subquery

Rename varable to fix mypy

Fix SA2.0 error: explicitly use subquery() for select-from argument

Fix SA2.0 error: replase session.bind with session.get_bind()

Fix SA2.0 error: joinedload does not take str args

Fix use of table model attribute

- Use __table__ (SA attr) instead of table (galaxy attr) on mapped classes
- Drop .table and .table.c where redundant

Fix bug: fix HistoryAudit model

It is not a child of RepresentById becuase it does not and should not have an id attr.

Duplicating the __repr__ definition in the HistoryAudit class is a
temporary fix: a proper fix requires changing all models (id and
__repr__ should be split into 2 mixins): to be done in a follow-up PR.

Fix bug: check if template.fields is not null before iterating

Fix bug: call unique() on result, not select stmt

Fix bug: do not pass subquery to in_

Fix bug/typo: use select_from

Fix bug: if using alias on ORM entity, use __table__ as valid FromClause

Fix bug: HDAH model is not serializable (caught by mypy)

Fix typing error: migrations.base

Fix typing error: managers.secured

This fixed 58 mypy errors!

Fix typing error: session type

Fix typing error: use Session instead of scoped_session

No need to pass around scoped_session as arguments

Fix typing error: sharable

Fix SA2.0 error: sqlalchemy exceptions import; minor mypy fix

Mypy: type-ignore: this is never SessionlessContext

Mypy: use verbose assignment to help mypy

Mypy: add assert stmt

Mypy: add assert to ensure seesion is not None

Calling that method when a User obj is not attached to a session should
not happen.

Mypy: return 0 if no results

Mypy: type-ignore: scoped_session vs. install_model_session

We use the disctinction for DI.

Mypy: refactor to one-liner

Mypy: add assert stmts where we know session returns an object

Mypy: rename wfi_step > wfi_step_sq when it becomes a subquery

Job search refactor: factor out build_job_subquery

Job search refactor: build_stmt_for_hda

Job search refactor: build_stmt_for_ldda

Job search refactor: build_stmt_for_hdca

Job search refactor: build_stmt_for_dce

Job search refactor: rename query >> stmt

Mypy: add anno for Lists; type-ignore for HDAs

Note: type-ignore is due to imperative mapping of HDAs (and LDDAs). This
will be removed once we map those models declaratively

Mypy: managers.histories

Mypy: model.deferred

Mypy: arg passed to template can be None

Mypy: celery tasks

type ignore arg: we need to map DatasetInstance classes declaratively
for that to work correctly.

Mypy: type-ignore hda attr-defined error

Need to map declaratively to remove this

Convert visualization manager index query to SA Core

Mypy: session is not none

Mypy: type-ignore what requires more refactoring

Mypy: type-ignore hda, ldda attrs: need declarative mapping

Also, minor SA2.0 syntax fix

Mypy: type-ignores to handle late evaluation of relationship arguments

Mypy: type-ignore column property assignments (type is correct)

Mypy: typing errors, misc. fixes

Mypy: all statements are reachable

Mypy: need to map hda declaratively, then its parent is model.Base

Fix typing errors: sharable, secured

Fix package mypy errors

Fix SA2.0 error: celery task

1. In 2.0, when the statement contains "returning", the result type is
   ChunkedIteratorResult, which does not have the rowcount attr,
   becuase:
2. result.rowcount should not be used for statements containting the returning clause

Ref: https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.CursorResult.rowcount

Wrap call to ensure session is closed

Otherwise there's an idle transaction left in the database (+locks)

Ensure session is closed on TS Registry load

Same as prev. commit: otherwise db locks are left

Fix SA2.0 error: list arg to select; mypy

Use NullPool for sqlite engines

This restores the behavior under SQLAlchemy 1.4
(Note that we set the pool for sqlite only if it's not an in-memory db

Help mypy: job is never None
  • Loading branch information
jdavcs committed Mar 29, 2024
1 parent 99aec81 commit 5bba5af
Show file tree
Hide file tree
Showing 108 changed files with 2,089 additions and 1,844 deletions.
2 changes: 1 addition & 1 deletion lib/galaxy/app_unittest_utils/galaxy_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(self, config=None, **kwargs) -> None:
self[ShortTermStorageMonitor] = sts_manager # type: ignore[type-abstract]
self[galaxy_scoped_session] = self.model.context
self.visualizations_registry = MockVisualizationsRegistry()
self.tag_handler = tags.GalaxyTagHandler(self.model.context)
self.tag_handler = tags.GalaxyTagHandler(self.model.session)
self[tags.GalaxyTagHandler] = self.tag_handler
self.quota_agent = quota.DatabaseQuotaAgent(self.model)
self.job_config = Bunch(
Expand Down
45 changes: 17 additions & 28 deletions lib/galaxy/celery/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from sqlalchemy.dialects.postgresql import insert as ps_insert
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from galaxy.model import CeleryUserRateLimit
from galaxy.model.base import transaction
Expand Down Expand Up @@ -70,7 +69,7 @@ def __call__(self, task: Task, task_id, args, kwargs):

@abstractmethod
def calculate_task_start_time(
self, user_id: int, sa_session: Session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
return now

Expand All @@ -81,38 +80,28 @@ class GalaxyTaskBeforeStartUserRateLimitPostgres(GalaxyTaskBeforeStartUserRateLi
We take advantage of efficiencies in its dialect.
"""

_update_stmt = (
update(CeleryUserRateLimit)
.where(CeleryUserRateLimit.user_id == bindparam("userid"))
.values(last_scheduled_time=text("greatest(last_scheduled_time + ':interval second', " ":now) "))
.returning(CeleryUserRateLimit.last_scheduled_time)
)

_insert_stmt = (
ps_insert(CeleryUserRateLimit)
.values(user_id=bindparam("userid"), last_scheduled_time=bindparam("now"))
.returning(CeleryUserRateLimit.last_scheduled_time)
)

_upsert_stmt = _insert_stmt.on_conflict_do_update(
index_elements=["user_id"], set_=dict(last_scheduled_time=bindparam("sched_time"))
)

def calculate_task_start_time( # type: ignore
self, user_id: int, sa_session: Session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
with transaction(sa_session):
result = sa_session.execute(
self._update_stmt, {"userid": user_id, "interval": task_interval_secs, "now": now}
update_stmt = (
update(CeleryUserRateLimit)
.where(CeleryUserRateLimit.user_id == user_id)
.values(last_scheduled_time=text("greatest(last_scheduled_time + ':interval second', " ":now) "))
.returning(CeleryUserRateLimit.last_scheduled_time)
)
if result.rowcount == 0:
result = sa_session.execute(update_stmt, {"interval": task_interval_secs, "now": now}).all()
if not result:
sched_time = now + datetime.timedelta(seconds=task_interval_secs)
result = sa_session.execute(
self._upsert_stmt, {"userid": user_id, "now": now, "sched_time": sched_time}
upsert_stmt = (
ps_insert(CeleryUserRateLimit) # type:ignore[attr-defined]
.values(user_id=user_id, last_scheduled_time=now)
.returning(CeleryUserRateLimit.last_scheduled_time)
.on_conflict_do_update(index_elements=["user_id"], set_=dict(last_scheduled_time=sched_time))
)
for row in result:
return row[0]
result = sa_session.execute(upsert_stmt).all()
sa_session.commit()
return result[0][0]


class GalaxyTaskBeforeStartUserRateLimitStandard(GalaxyTaskBeforeStartUserRateLimit):
Expand All @@ -138,7 +127,7 @@ class GalaxyTaskBeforeStartUserRateLimitStandard(GalaxyTaskBeforeStartUserRateLi
)

def calculate_task_start_time(
self, user_id: int, sa_session: Session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
last_scheduled_time = None
with transaction(sa_session):
Expand Down
5 changes: 4 additions & 1 deletion lib/galaxy/celery/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def set_metadata(
try:
if overwrite:
hda_manager.overwrite_metadata(dataset_instance)
dataset_instance.datatype.set_meta(dataset_instance)
dataset_instance.datatype.set_meta(dataset_instance) # type:ignore [arg-type]
dataset_instance.set_peek()
# Reset SETTING_METADATA state so the dataset instance getter picks the dataset state
dataset_instance.set_metadata_success_state()
Expand Down Expand Up @@ -228,6 +228,7 @@ def setup_fetch_data(
):
tool = cached_create_tool_from_representation(app=app, raw_tool_source=raw_tool_source)
job = sa_session.get(Job, job_id)
assert job
# self.request.hostname is the actual worker name given by the `-n` argument, not the hostname as you might think.
job.handler = self.request.hostname
job.job_runner_name = "celery"
Expand Down Expand Up @@ -260,6 +261,7 @@ def finish_job(
):
tool = cached_create_tool_from_representation(app=app, raw_tool_source=raw_tool_source)
job = sa_session.get(Job, job_id)
assert job
# TODO: assert state ?
mini_job_wrapper = MinimalJobWrapper(job=job, app=app, tool=tool)
mini_job_wrapper.finish("", "")
Expand Down Expand Up @@ -320,6 +322,7 @@ def fetch_data(
task_user_id: Optional[int] = None,
) -> str:
job = sa_session.get(Job, job_id)
assert job
mini_job_wrapper = MinimalJobWrapper(job=job, app=app)
mini_job_wrapper.change_state(model.Job.states.RUNNING, flush=True, job=job)
return abort_when_job_stops(_fetch_data, session=sa_session, job_id=job_id, setup_return=setup_return)
Expand Down
31 changes: 0 additions & 31 deletions lib/galaxy/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,6 @@ class GalaxyAppConfiguration(BaseAppConfiguration, CommonConfigurationMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._override_tempdir(kwargs)
self._configure_sqlalchemy20_warnings(kwargs)
self._process_config(kwargs)
self._set_dependent_defaults()

Expand All @@ -764,36 +763,6 @@ def _set_dependent_defaults(self):
f"{dependent_config_param}, {config_param}"
)

def _configure_sqlalchemy20_warnings(self, kwargs):
"""
This method should be deleted after migration to SQLAlchemy 2.0 is complete.
To enable warnings, set `GALAXY_CONFIG_SQLALCHEMY_WARN_20=1`,
"""
warn = string_as_bool(kwargs.get("sqlalchemy_warn_20", False))
if warn:
import sqlalchemy

sqlalchemy.util.deprecations.SQLALCHEMY_WARN_20 = True
self._setup_sqlalchemy20_warnings_filters()

def _setup_sqlalchemy20_warnings_filters(self):
import warnings

from sqlalchemy.exc import RemovedIn20Warning

# Always display RemovedIn20Warning warnings.
warnings.filterwarnings("always", category=RemovedIn20Warning)
# Optionally, enable filters for specific warnings (raise error, or log, etc.)
# messages = [
# r"replace with warning text to match",
# ]
# for msg in messages:
# warnings.filterwarnings('error', message=msg, category=RemovedIn20Warning)
#
# See documentation:
# https://docs.python.org/3.7/library/warnings.html#the-warnings-filter
# https://docs.sqlalchemy.org/en/14/changelog/migration_20.html#migration-to-2-0-step-three-resolve-all-removedin20warnings

def _load_schema(self):
return AppSchema(GALAXY_CONFIG_SCHEMA_PATH, GALAXY_APP_NAME)

Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/dependencies/pinned-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ sniffio==1.3.1 ; python_version >= "3.8" and python_version < "3.13"
social-auth-core==4.5.3 ; python_version >= "3.8" and python_version < "3.13"
sortedcontainers==2.4.0 ; python_version >= "3.8" and python_version < "3.13"
spython==0.3.13 ; python_version >= "3.8" and python_version < "3.13"
sqlalchemy==1.4.52 ; python_version >= "3.8" and python_version < "3.13"
sqlalchemy==2.0.25 ; python_version >= "3.8" and python_version < "3.13"
sqlitedict==2.1.0 ; python_version >= "3.8" and python_version < "3.13"
sqlparse==0.4.4 ; python_version >= "3.8" and python_version < "3.13"
starlette-context==0.3.6 ; python_version >= "3.8" and python_version < "3.13"
Expand Down
6 changes: 4 additions & 2 deletions lib/galaxy/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,9 @@ def galaxy_url(self):
return self.get_destination_configuration("galaxy_infrastructure_url")

def get_job(self) -> model.Job:
return self.sa_session.get(Job, self.job_id)
job = self.sa_session.get(Job, self.job_id)
assert job
return job

def get_id_tag(self):
# For compatibility with drmaa, which uses job_id right now, and TaskWrapper
Expand Down Expand Up @@ -1552,7 +1554,7 @@ def change_state(self, state, info=False, flush=True, job=None):
def get_state(self) -> str:
job = self.get_job()
self.sa_session.refresh(job)
return job.state
return job.state # type:ignore[return-value]

def set_runner(self, runner_url, external_id):
log.warning("set_runner() is deprecated, use set_job_destination()")
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ def _one_with_recast_errors(self, query: Query) -> U:
# overridden to raise serializable errors
try:
return query.one()
except sqlalchemy.orm.exc.NoResultFound:
except sqlalchemy.exc.NoResultFound:
raise exceptions.ObjectNotFound(f"{self.model_class.__name__} not found")
except sqlalchemy.orm.exc.MultipleResultsFound:
except sqlalchemy.exc.MultipleResultsFound:
raise exceptions.InconsistentDatabase(f"found more than one {self.model_class.__name__}")

# NOTE: at this layer, all ids are expected to be decoded and in int form
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def get_collection_contents(self, trans: ProvidesAppContext, parent_id, limit=No
def _get_collection_contents_qry(self, parent_id, limit=None, offset=None):
"""Build query to find first level of collection contents by containing collection parent_id"""
DCE = model.DatasetCollectionElement
qry = Query(DCE).filter(DCE.dataset_collection_id == parent_id)
qry = Query(DCE).filter(DCE.dataset_collection_id == parent_id) # type:ignore[var-annotated]
qry = qry.order_by(DCE.element_index)
qry = qry.options(
joinedload(model.DatasetCollectionElement.child_collection), joinedload(model.DatasetCollectionElement.hda)
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def purge_datasets(self, request: PurgeDatasetsTaskRequest):
self.error_unless_dataset_purge_allowed()
with self.session().begin():
for dataset_id in request.dataset_ids:
dataset: Dataset = self.session().get(Dataset, dataset_id)
if dataset.user_can_purge:
dataset: Optional[Dataset] = self.session().get(Dataset, dataset_id)
if dataset and dataset.user_can_purge:
try:
dataset.full_delete()
except Exception:
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/dbkeys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
)

from sqlalchemy import select
from sqlalchemy.orm import Session

from galaxy.model import HistoryDatasetAssociation
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.util import (
galaxy_directory,
sanitize_lists_to_string,
Expand Down Expand Up @@ -166,6 +166,6 @@ def get_chrom_info(self, dbkey, trans=None, custom_build_hack_get_len_from_fasta
return (chrom_info, db_dataset)


def get_len_files_by_history(session: Session, history_id: int):
def get_len_files_by_history(session: galaxy_scoped_session, history_id: int):
stmt = select(HistoryDatasetAssociation).filter_by(history_id=history_id, extension="len", deleted=False)
return session.scalars(stmt)
6 changes: 3 additions & 3 deletions lib/galaxy/managers/export_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
and_,
select,
)
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm.scoping import scoped_session

from galaxy.exceptions import ObjectNotFound
Expand Down Expand Up @@ -44,7 +44,7 @@ def set_export_association_metadata(self, export_association_id: int, export_met
export_association: StoreExportAssociation = self.session.execute(stmt).scalars().one()
except NoResultFound:
raise ObjectNotFound("Cannot set export metadata. Reason: Export association not found")
export_association.export_metadata = export_metadata.json()
export_association.export_metadata = export_metadata.json() # type:ignore[assignment]
with transaction(self.session):
self.session.commit()

Expand Down Expand Up @@ -76,4 +76,4 @@ def get_object_exports(
stmt = stmt.offset(offset)
if limit:
stmt = stmt.limit(limit)
return self.session.execute(stmt).scalars()
return self.session.execute(stmt).scalars() # type:ignore[return-value]
9 changes: 5 additions & 4 deletions lib/galaxy/managers/folders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
or_,
select,
)
from sqlalchemy.orm import aliased
from sqlalchemy.orm.exc import (
from sqlalchemy.exc import (
MultipleResultsFound,
NoResultFound,
)
from sqlalchemy.orm import aliased

from galaxy import (
model,
Expand Down Expand Up @@ -505,7 +505,7 @@ def _get_contained_datasets_statement(
stmt = stmt.where(
or_(
func.lower(ldda.name).contains(search_text, autoescape=True),
func.lower(ldda.message).contains(search_text, autoescape=True),
func.lower(ldda.message).contains(search_text, autoescape=True), # type:ignore[attr-defined]
)
)
sort_column = LDDA_SORT_COLUMN_MAP[payload.order_by](ldda, associated_dataset)
Expand Down Expand Up @@ -536,7 +536,7 @@ def _filter_by_include_deleted(

def build_folder_path(
self, sa_session: galaxy_scoped_session, folder: model.LibraryFolder
) -> List[Tuple[str, str]]:
) -> List[Tuple[int, Optional[str]]]:
"""
Returns the folder path from root to the given folder.
Expand All @@ -546,6 +546,7 @@ def build_folder_path(
path_to_root = [(current_folder.id, current_folder.name)]
while current_folder.parent_id is not None:
parent_folder = sa_session.get(LibraryFolder, current_folder.parent_id)
assert parent_folder
current_folder = parent_folder
path_to_root.insert(0, (current_folder.id, current_folder.name))
return path_to_root
Expand Down
9 changes: 6 additions & 3 deletions lib/galaxy/managers/forms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from sqlalchemy import select
from sqlalchemy.orm import exc as sqlalchemy_exceptions
from sqlalchemy.exc import (
MultipleResultsFound,
NoResultFound,
)

from galaxy.exceptions import (
InconsistentDatabase,
Expand Down Expand Up @@ -59,9 +62,9 @@ def get(self, trans: ProvidesUserContext, form_id: int) -> FormDefinitionCurrent
try:
stmt = select(FormDefinitionCurrent).where(FormDefinitionCurrent.id == form_id)
form = self.session().execute(stmt).scalar_one()
except sqlalchemy_exceptions.MultipleResultsFound:
except MultipleResultsFound:
raise InconsistentDatabase("Multiple forms found with the same id.")
except sqlalchemy_exceptions.NoResultFound:
except NoResultFound:
raise RequestParameterInvalidException("No accessible form found with the id provided.")
except Exception as e:
raise InternalServerError(f"Error loading from the database.{unicodify(e)}")
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/genomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _create_genome_filter(model_class=None):
if self.database_connection.startswith("postgres"):
column = text("convert_from(metadata, 'UTF8')::json ->> 'dbkey'")
else:
column = func.json_extract(model_class.table.c._metadata, "$.dbkey")
column = func.json_extract(model_class.table.c._metadata, "$.dbkey") # type:ignore[assignment]
lower_val = val.lower() # Ignore case
# dbkey can either be "hg38" or '["hg38"]', so we need to check both
if op == "eq":
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/group_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
)

from sqlalchemy import select
from sqlalchemy.orm import Session

from galaxy import model
from galaxy.exceptions import ObjectNotFound
from galaxy.managers.context import ProvidesAppContext
from galaxy.model import GroupRoleAssociation
from galaxy.model.base import transaction
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.structured_app import MinimalManagerApp

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -93,7 +93,7 @@ def _remove_role_from_group(self, trans: ProvidesAppContext, group_role: model.G
trans.sa_session.commit()


def get_group_role(session: Session, group, role) -> Optional[GroupRoleAssociation]:
def get_group_role(session: galaxy_scoped_session, group, role) -> Optional[GroupRoleAssociation]:
stmt = (
select(GroupRoleAssociation).where(GroupRoleAssociation.group == group).where(GroupRoleAssociation.role == role)
)
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/group_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
)

from sqlalchemy import select
from sqlalchemy.orm import Session

from galaxy import model
from galaxy.exceptions import ObjectNotFound
Expand All @@ -15,6 +14,7 @@
UserGroupAssociation,
)
from galaxy.model.base import transaction
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.structured_app import MinimalManagerApp

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -96,7 +96,7 @@ def _remove_user_from_group(self, trans: ProvidesAppContext, group_user: model.U
trans.sa_session.commit()


def get_group_user(session: Session, user, group) -> Optional[UserGroupAssociation]:
def get_group_user(session: galaxy_scoped_session, user, group) -> Optional[UserGroupAssociation]:
stmt = (
select(UserGroupAssociation).where(UserGroupAssociation.user == user).where(UserGroupAssociation.group == group)
)
Expand Down
Loading

0 comments on commit 5bba5af

Please sign in to comment.