diff --git a/lib/galaxy/celery/__init__.py b/lib/galaxy/celery/__init__.py index 11432481529c..dd9f26acd91f 100644 --- a/lib/galaxy/celery/__init__.py +++ b/lib/galaxy/celery/__init__.py @@ -144,7 +144,7 @@ def init_fork_pool(): @worker_init.connect def setup_worker_pool(sender=None, conf=None, instance=None, **kwargs): context = get_context("forkserver") - celery_app.fork_pool = pebble.ProcessPool( + get_celery_app().fork_pool = pebble.ProcessPool( max_workers=sender.concurrency, max_tasks=100, initializer=init_fork_pool, context=context ) @@ -152,6 +152,7 @@ def setup_worker_pool(sender=None, conf=None, instance=None, **kwargs): @worker_shutting_down.connect def tear_down_pool(sig, how, exitcode, **kwargs): log.debug("shutting down forkserver pool") + celery_app = get_celery_app() celery_app.fork_pool.stop() celery_app.fork_pool.join(timeout=5) @@ -200,14 +201,15 @@ def wrapper(*args, **kwds): return decorate -def init_celery_app(): - celery_app_kwd: Dict[str, Any] = { - "include": TASKS_MODULES, - "task_default_queue": DEFAULT_TASK_QUEUE, - "task_create_missing_queues": True, - "timezone": "UTC", - } - celery_app = GalaxyCelery("galaxy", **celery_app_kwd) +@lru_cache(maxsize=1) +def get_celery_app() -> GalaxyCelery: + celery_app = GalaxyCelery( + "galaxy", + include=TASKS_MODULES, + task_default_queue=DEFAULT_TASK_QUEUE, + task_create_missing_queues=True, + timezone="UTC", + ) celery_app.set_default() config = get_config() config_celery_app(config, celery_app) @@ -248,6 +250,3 @@ def schedule_task(task, interval): if beat_schedule: celery_app.conf.beat_schedule = beat_schedule - - -celery_app = init_celery_app() diff --git a/lib/galaxy/celery/tasks.py b/lib/galaxy/celery/tasks.py index 3b2e4c6272a7..a8ac50cc4d89 100644 --- a/lib/galaxy/celery/tasks.py +++ b/lib/galaxy/celery/tasks.py @@ -15,8 +15,8 @@ from galaxy import model from galaxy.celery import ( - celery_app, galaxy_task, + get_celery_app, ) from galaxy.config import GalaxyAppConfiguration from galaxy.datatypes import sniff @@ -286,7 +286,7 @@ def is_aborted(session: galaxy_scoped_session, job_id: int): def abort_when_job_stops(function: Callable, session: galaxy_scoped_session, job_id: int, **kwargs) -> Any: if not is_aborted(session, job_id): - future = celery_app.fork_pool.submit( + future = get_celery_app().fork_pool.submit( function, timeout=None, **kwargs, diff --git a/lib/galaxy/jobs/handler.py b/lib/galaxy/jobs/handler.py index 0213a797aab4..6bc29f105944 100644 --- a/lib/galaxy/jobs/handler.py +++ b/lib/galaxy/jobs/handler.py @@ -30,6 +30,7 @@ ) from galaxy import model +from galaxy.celery import get_celery_app from galaxy.exceptions import ObjectNotFound from galaxy.jobs import ( JobDestination, @@ -1253,9 +1254,7 @@ def stop(self, job, job_wrapper): # If we're stopping a task, then the runner_name may be # None, in which case it hasn't been scheduled. if self.app.config.enable_celery_tasks and job.tool_id == "__DATA_FETCH__": - from galaxy.celery import celery_app - - celery_app.control.revoke(job.job_runner_external_id) + get_celery_app().control.revoke(job.job_runner_external_id) if (job_runner_name := job.get_job_runner_name()) is not None: runner_name = job_runner_name.split(":", 1)[0] log.debug(f"Stopping job {job_wrapper.get_id_tag()} in {runner_name} runner") diff --git a/lib/galaxy_test/base/api.py b/lib/galaxy_test/base/api.py index 63e33dd5e90a..b55df5c22816 100644 --- a/lib/galaxy_test/base/api.py +++ b/lib/galaxy_test/base/api.py @@ -62,8 +62,9 @@ def _request_celery_app(self, celery_session_app, celery_config): yield finally: if os.environ.get("GALAXY_TEST_EXTERNAL") is None: - from galaxy.celery import celery_app + from galaxy.celery import get_celery_app + celery_app = get_celery_app() celery_app.fork_pool.stop() celery_app.fork_pool.join(timeout=5) diff --git a/test/unit/app/test_celery.py b/test/unit/app/test_celery.py index db16b23b4b88..a0e594d534f0 100644 --- a/test/unit/app/test_celery.py +++ b/test/unit/app/test_celery.py @@ -1,14 +1,14 @@ from galaxy.celery import ( - celery_app, DEFAULT_TASK_QUEUE, GalaxyCelery, + get_celery_app, TASKS_MODULES, ) from galaxy.config import GalaxyAppConfiguration def test_default_configuration(): - conf = celery_app.conf + conf = get_celery_app().conf galaxy_conf = GalaxyAppConfiguration(override_tempdir=False) assert conf.task_default_queue == DEFAULT_TASK_QUEUE