From cb46a1249322b8e5078195c8d39e94c00ff25863 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 25 Apr 2023 10:44:35 -0700 Subject: [PATCH] signal handling: User-defined interrupt handlers Interrupt handling is a tricky problem, not just in terms of implementation, but in terms of desired behavior: when an interrupt is received, which code should handle it? Julia's current answer to this is effectively to throw an `InterruptException` to the first task to hit a safepoint. While this seems sensible (the code that's running gets interrupted), it only really works for very basic numerical code. In the case that multiple tasks are running concurrently, or when try-catch handlers are registered, this system breaks down, and results in unpredictable behavior. This unpredictable behavior includes: - Interrupting background/runtime tasks which don't want to be interrupted, as they do little bits of important work (and are critical to library runtime functionality) - Interrupting only one task, when multiple coordinating tasks would want to receive the interrupt to safely terminate a computation - Interrupting only one library's task, when multiple libraries really would want to be notified about the interrupt The above behavior makes it nearly impossible to provide reliable Ctrl-C behavior, and results in very confused users who get stuck hitting Ctrl-C continuously, sometimes getting caught in a hang, sometimes triggering unrelated exception handling code they didn't mean to, sometimes getting a segfault, and very rarely getting the behavior they desire (with unpredictable safety of being able to continue using the active session as intended). This commit provides an alternative behavior for interrupts which is more predictable: user code may now register tasks as "interrupt handlers", which will be guaranteed to receive an `InterruptException` whenever the session receives an interrupt signal. Additionally, when any interrupt handlers are registered, no other tasks will receive `InterruptException`s; only the handlers may receive them. This behavior allows one or more libraries to register handler tasks which will all be concurrently awoken to handle each interrupt and do whatever is necessary to safely interrupt any running code; the extent to which other tasks are interrupted is arbitrary and library-defined. For example, GPU libraries like AMDGPU.jl can register a handler to safely interrupt GPU kernels running on all GPU queues and do resource cleanup. Concurrently, a complex runtime like the scheduler in Dagger.jl can register a handler to interrupt running tasks on other workers when possible, and otherwise notify the user that tasks are being shutdown. This change is intended to be non-breaking for simple codes: the previous behavior is maintained when no interrupt handlers are registered. However, once some libraries start adding interrupt handlers, other libraries will need to follow suit to ensure that users can interrupt their computations. --- base/task.jl | 14 +++++++++++ src/gc.c | 1 + src/jl_exported_data.inc | 1 + src/jl_exported_funcs.inc | 1 + src/julia_threads.h | 2 ++ src/signal-handling.c | 49 +++++++++++++++++++++++++++++++++++++++ src/signals-unix.c | 11 +++++++-- src/signals-win.c | 3 ++- src/task.c | 7 ++++-- src/threading.h | 2 ++ 10 files changed, 86 insertions(+), 5 deletions(-) diff --git a/base/task.jl b/base/task.jl index ffe8e5665b041f..8072761a5ea22a 100644 --- a/base/task.jl +++ b/base/task.jl @@ -992,3 +992,17 @@ if Sys.iswindows() else pause() = ccall(:pause, Cvoid, ()) end + +interrupt_handlers() = ccall(:jl_get_interrupt_handlers, Any, ())::Vector{Task} +function register_interrupt_handler(t::Task) + handlers = interrupt_handlers() + if findfirst(==(t), handlers) === nothing + push!(handlers, t) + end + return +end +function unregister_interrupt_handler(t::Task) + handlers = interrupt_handlers() + deleteat!(handlers, findall(==(t), handlers)) + return +end diff --git a/src/gc.c b/src/gc.c index 3afddc4afb3d86..79b9bace3a5b5e 100644 --- a/src/gc.c +++ b/src/gc.c @@ -2679,6 +2679,7 @@ static void gc_mark_roots(jl_gc_markqueue_t *mq) gc_try_claim_and_push(mq, jl_emptytuple_type, NULL); gc_try_claim_and_push(mq, cmpswap_names, NULL); gc_try_claim_and_push(mq, jl_global_roots_table, NULL); + gc_try_claim_and_push(mq, jl_interrupt_handlers, NULL); } // find unmarked objects that need to be finalized from the finalizer list "list". diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index 52f6cb11d8c0f6..d388f5473410ae 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -54,6 +54,7 @@ XX(jl_int8_type) \ XX(jl_interconditional_type) \ XX(jl_interrupt_exception) \ + XX(jl_interrupt_handlers) \ XX(jl_intrinsic_type) \ XX(jl_kwcall_func) \ XX(jl_lineinfonode_type) \ diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index d3acb7d2ad92af..3bd4e22e802dfe 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -220,6 +220,7 @@ XX(jl_get_field) \ XX(jl_get_global) \ XX(jl_get_image_file) \ + XX(jl_get_interrupt_handlers) \ XX(jl_get_JIT) \ XX(jl_get_julia_bin) \ XX(jl_get_julia_bindir) \ diff --git a/src/julia_threads.h b/src/julia_threads.h index 6439caa0aa2eed..bb14d503d270b9 100644 --- a/src/julia_threads.h +++ b/src/julia_threads.h @@ -369,6 +369,8 @@ JL_DLLEXPORT int8_t jl_gc_is_in_finalizer(void); JL_DLLEXPORT void jl_wakeup_thread(int16_t tid); +JL_DLLEXPORT void jl_schedule_task(struct _jl_task_t *task); + #ifdef __cplusplus } #endif diff --git a/src/signal-handling.c b/src/signal-handling.c index e241fd22ecb186..2c0f8f18095ee3 100644 --- a/src/signal-handling.c +++ b/src/signal-handling.c @@ -304,6 +304,55 @@ static void jl_check_profile_autostop(void) } } +JL_DLLEXPORT jl_array_t *jl_interrupt_handlers = NULL; +JL_DLLEXPORT jl_array_t *jl_get_interrupt_handlers(void) +{ + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (!handlers) { + static jl_datatype_t *jl_array_task_type; + if (!jl_array_task_type) + jl_array_task_type = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_task_type, jl_box_long(1)); + jl_array_t *new_handlers = jl_alloc_array_1d((jl_value_t *)jl_array_task_type, 0); + if (jl_atomic_cmpswap(&jl_interrupt_handlers, &handlers, new_handlers)) { + handlers = new_handlers; + } else { + handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + } + } + assert(handlers); + return handlers; +} +static _Atomic(int) handle_interrupt = 0; +JL_DLLEXPORT void jl_schedule_interrupt_handlers(void) +{ + if (jl_atomic_exchange_relaxed(&handle_interrupt, 0) != 1) + return; + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (!handlers) + return; + for (int i = 0; i < jl_array_len(handlers); i++) { + jl_task_t *handler = ((jl_task_t **)jl_array_data(handlers))[i]; + assert(jl_is_task(handler)); + if (handler->ptls) + continue; + if (jl_atomic_load_relaxed(&handler->_state) != JL_TASK_STATE_RUNNABLE) + continue; + handler->result = jl_interrupt_exception; + handler->_isexception = 1; + jl_schedule_task(handler); + } +} +static int want_interrupt_handlers(void) +{ + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (handlers && (jl_array_len(handlers) > 0)) { + // Set flag to trigger user handlers on next task switch + jl_atomic_store_relaxed(&handle_interrupt, 1); + return 1; + } + return 0; +} + #if defined(_WIN32) #include "signals-win.c" #else diff --git a/src/signals-unix.c b/src/signals-unix.c index 6ed664199fd2b4..266cb4c0b66237 100644 --- a/src/signals-unix.c +++ b/src/signals-unix.c @@ -527,11 +527,14 @@ void usr2_handler(int sig, siginfo_t *info, void *ctx) jl_atomic_exchange(&ptls->signal_request, 0); // returns -1 if (request == 2) { int force = jl_check_force_sigint(); + if (!force && want_interrupt_handlers()) { + return; + } if (force || (!ptls->defer_signal && ptls->io_wait)) { jl_safepoint_consume_sigint(); + // Force a throw if (force) jl_safe_printf("WARNING: Force throwing a SIGINT\n"); - // Force a throw jl_clear_force_sigint(); jl_throw_in_ctx(ct, jl_interrupt_exception, sig, ctx); } @@ -802,7 +805,7 @@ static void *signal_listener(void *arg) profile = (sig == SIGUSR1); #if defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L if (profile && !(info.si_code == SI_TIMER && - info.si_value.sival_ptr == &timerprof)) + info.si_value.sival_ptr == &timerprof)) profile = 0; #endif #elif defined(HAVE_ITIMER) @@ -817,6 +820,10 @@ static void *signal_listener(void *arg) else if (exit_on_sigint) { critical = 1; } + // FIXME: Skip this if force + else if (want_interrupt_handlers()) { + continue; + } else { jl_try_deliver_sigint(); continue; diff --git a/src/signals-win.c b/src/signals-win.c index f20a4d5287669f..c5720d84d19d2a 100644 --- a/src/signals-win.c +++ b/src/signals-win.c @@ -221,7 +221,8 @@ static BOOL WINAPI sigint_handler(DWORD wsig) //This needs winapi types to guara if (!jl_ignore_sigint()) { if (exit_on_sigint) jl_exit(128 + sig); // 128 + SIGINT - jl_try_deliver_sigint(); + if (!want_interrupt_handlers()) + jl_try_deliver_sigint(); } return 1; } diff --git a/src/task.c b/src/task.c index 7373de937b9aeb..ecb21ea2dc58fd 100644 --- a/src/task.c +++ b/src/task.c @@ -621,8 +621,12 @@ JL_NO_ASAN static void ctx_switch(jl_task_t *lastt) sanitizer_finish_switch_fiber(ptls->previous_task, jl_atomic_load_relaxed(&ptls->current_task)); } +JL_DLLIMPORT void jl_schedule_interrupt_handlers(void); + JL_DLLEXPORT void jl_switch(void) JL_NOTSAFEPOINT_LEAVE JL_NOTSAFEPOINT_ENTER { + jl_schedule_interrupt_handlers(); + jl_task_t *ct = jl_current_task; jl_ptls_t ptls = ct->ptls; jl_task_t *t = ptls->next_task; @@ -997,7 +1001,7 @@ JL_DLLEXPORT void jl_task_wait() jl_apply(&wait_func, 1); ct->world_age = last_age; } - +#endif JL_DLLEXPORT void jl_schedule_task(jl_task_t *task) { static jl_function_t *sched_func = NULL; @@ -1011,7 +1015,6 @@ JL_DLLEXPORT void jl_schedule_task(jl_task_t *task) jl_apply(args, 2); ct->world_age = last_age; } -#endif // Do one-time initializations for task system void jl_init_tasks(void) JL_GC_DISABLED diff --git a/src/threading.h b/src/threading.h index 4df6815124eb9c..c04df9c1c92401 100644 --- a/src/threading.h +++ b/src/threading.h @@ -27,6 +27,8 @@ jl_ptls_t jl_init_threadtls(int16_t tid) JL_NOTSAFEPOINT; void jl_init_threadinginfra(void); void jl_threadfun(void *arg); +extern jl_array_t *jl_interrupt_handlers; + #ifdef __cplusplus } #endif