Skip to content

Commit

Permalink
Replace global `celery_app variable with get_celery_app() function
Browse files Browse the repository at this point in the history
  • Loading branch information
nsoranzo committed Sep 16, 2024
1 parent 236f1b3 commit 4e521fa
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 20 deletions.
23 changes: 11 additions & 12 deletions lib/galaxy/celery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,15 @@ def init_fork_pool():
@worker_init.connect
def setup_worker_pool(sender=None, conf=None, instance=None, **kwargs):
context = get_context("forkserver")
celery_app.fork_pool = pebble.ProcessPool(
get_celery_app().fork_pool = pebble.ProcessPool(
max_workers=sender.concurrency, max_tasks=100, initializer=init_fork_pool, context=context
)


@worker_shutting_down.connect
def tear_down_pool(sig, how, exitcode, **kwargs):
log.debug("shutting down forkserver pool")
celery_app = get_celery_app()
celery_app.fork_pool.stop()
celery_app.fork_pool.join(timeout=5)

Expand Down Expand Up @@ -200,14 +201,15 @@ def wrapper(*args, **kwds):
return decorate


def init_celery_app():
celery_app_kwd: Dict[str, Any] = {
"include": TASKS_MODULES,
"task_default_queue": DEFAULT_TASK_QUEUE,
"task_create_missing_queues": True,
"timezone": "UTC",
}
celery_app = GalaxyCelery("galaxy", **celery_app_kwd)
@lru_cache(maxsize=1)
def get_celery_app() -> GalaxyCelery:
celery_app = GalaxyCelery(
"galaxy",
include=TASKS_MODULES,
task_default_queue=DEFAULT_TASK_QUEUE,
task_create_missing_queues=True,
timezone="UTC",
)
celery_app.set_default()
config = get_config()
config_celery_app(config, celery_app)
Expand Down Expand Up @@ -248,6 +250,3 @@ def schedule_task(task, interval):

if beat_schedule:
celery_app.conf.beat_schedule = beat_schedule


celery_app = init_celery_app()
4 changes: 2 additions & 2 deletions lib/galaxy/celery/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from galaxy import model
from galaxy.celery import (
celery_app,
galaxy_task,
get_celery_app,
)
from galaxy.config import GalaxyAppConfiguration
from galaxy.datatypes import sniff
Expand Down Expand Up @@ -286,7 +286,7 @@ def is_aborted(session: galaxy_scoped_session, job_id: int):

def abort_when_job_stops(function: Callable, session: galaxy_scoped_session, job_id: int, **kwargs) -> Any:
if not is_aborted(session, job_id):
future = celery_app.fork_pool.submit(
future = get_celery_app().fork_pool.submit(
function,
timeout=None,
**kwargs,
Expand Down
5 changes: 2 additions & 3 deletions lib/galaxy/jobs/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)

from galaxy import model
from galaxy.celery import get_celery_app
from galaxy.exceptions import ObjectNotFound
from galaxy.jobs import (
JobDestination,
Expand Down Expand Up @@ -1253,9 +1254,7 @@ def stop(self, job, job_wrapper):
# If we're stopping a task, then the runner_name may be
# None, in which case it hasn't been scheduled.
if self.app.config.enable_celery_tasks and job.tool_id == "__DATA_FETCH__":
from galaxy.celery import celery_app

celery_app.control.revoke(job.job_runner_external_id)
get_celery_app().control.revoke(job.job_runner_external_id)
if (job_runner_name := job.get_job_runner_name()) is not None:
runner_name = job_runner_name.split(":", 1)[0]
log.debug(f"Stopping job {job_wrapper.get_id_tag()} in {runner_name} runner")
Expand Down
3 changes: 2 additions & 1 deletion lib/galaxy_test/base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def _request_celery_app(self, celery_session_app, celery_config):
yield
finally:
if os.environ.get("GALAXY_TEST_EXTERNAL") is None:
from galaxy.celery import celery_app
from galaxy.celery import get_celery_app

celery_app = get_celery_app()
celery_app.fork_pool.stop()
celery_app.fork_pool.join(timeout=5)

Expand Down
4 changes: 2 additions & 2 deletions test/unit/app/test_celery.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from galaxy.celery import (
celery_app,
DEFAULT_TASK_QUEUE,
GalaxyCelery,
get_celery_app,
TASKS_MODULES,
)
from galaxy.config import GalaxyAppConfiguration


def test_default_configuration():
conf = celery_app.conf
conf = get_celery_app().conf
galaxy_conf = GalaxyAppConfiguration(override_tempdir=False)

assert conf.task_default_queue == DEFAULT_TASK_QUEUE
Expand Down

0 comments on commit 4e521fa

Please sign in to comment.