diff --git a/base/task.jl b/base/task.jl index ffe8e5665b041f..8072761a5ea22a 100644 --- a/base/task.jl +++ b/base/task.jl @@ -992,3 +992,17 @@ if Sys.iswindows() else pause() = ccall(:pause, Cvoid, ()) end + +interrupt_handlers() = ccall(:jl_get_interrupt_handlers, Any, ())::Vector{Task} +function register_interrupt_handler(t::Task) + handlers = interrupt_handlers() + if findfirst(==(t), handlers) === nothing + push!(handlers, t) + end + return +end +function unregister_interrupt_handler(t::Task) + handlers = interrupt_handlers() + deleteat!(handlers, findall(==(t), handlers)) + return +end diff --git a/src/gc.c b/src/gc.c index 3afddc4afb3d86..79b9bace3a5b5e 100644 --- a/src/gc.c +++ b/src/gc.c @@ -2679,6 +2679,7 @@ static void gc_mark_roots(jl_gc_markqueue_t *mq) gc_try_claim_and_push(mq, jl_emptytuple_type, NULL); gc_try_claim_and_push(mq, cmpswap_names, NULL); gc_try_claim_and_push(mq, jl_global_roots_table, NULL); + gc_try_claim_and_push(mq, jl_interrupt_handlers, NULL); } // find unmarked objects that need to be finalized from the finalizer list "list". diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index 52f6cb11d8c0f6..d388f5473410ae 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -54,6 +54,7 @@ XX(jl_int8_type) \ XX(jl_interconditional_type) \ XX(jl_interrupt_exception) \ + XX(jl_interrupt_handlers) \ XX(jl_intrinsic_type) \ XX(jl_kwcall_func) \ XX(jl_lineinfonode_type) \ diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index d3acb7d2ad92af..3bd4e22e802dfe 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -220,6 +220,7 @@ XX(jl_get_field) \ XX(jl_get_global) \ XX(jl_get_image_file) \ + XX(jl_get_interrupt_handlers) \ XX(jl_get_JIT) \ XX(jl_get_julia_bin) \ XX(jl_get_julia_bindir) \ diff --git a/src/julia_threads.h b/src/julia_threads.h index 6439caa0aa2eed..bb14d503d270b9 100644 --- a/src/julia_threads.h +++ b/src/julia_threads.h @@ -369,6 +369,8 @@ JL_DLLEXPORT int8_t jl_gc_is_in_finalizer(void); JL_DLLEXPORT void jl_wakeup_thread(int16_t tid); +JL_DLLEXPORT void jl_schedule_task(struct _jl_task_t *task); + #ifdef __cplusplus } #endif diff --git a/src/signal-handling.c b/src/signal-handling.c index e241fd22ecb186..2c0f8f18095ee3 100644 --- a/src/signal-handling.c +++ b/src/signal-handling.c @@ -304,6 +304,55 @@ static void jl_check_profile_autostop(void) } } +JL_DLLEXPORT jl_array_t *jl_interrupt_handlers = NULL; +JL_DLLEXPORT jl_array_t *jl_get_interrupt_handlers(void) +{ + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (!handlers) { + static jl_datatype_t *jl_array_task_type; + if (!jl_array_task_type) + jl_array_task_type = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_task_type, jl_box_long(1)); + jl_array_t *new_handlers = jl_alloc_array_1d((jl_value_t *)jl_array_task_type, 0); + if (jl_atomic_cmpswap(&jl_interrupt_handlers, &handlers, new_handlers)) { + handlers = new_handlers; + } else { + handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + } + } + assert(handlers); + return handlers; +} +static _Atomic(int) handle_interrupt = 0; +JL_DLLEXPORT void jl_schedule_interrupt_handlers(void) +{ + if (jl_atomic_exchange_relaxed(&handle_interrupt, 0) != 1) + return; + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (!handlers) + return; + for (int i = 0; i < jl_array_len(handlers); i++) { + jl_task_t *handler = ((jl_task_t **)jl_array_data(handlers))[i]; + assert(jl_is_task(handler)); + if (handler->ptls) + continue; + if (jl_atomic_load_relaxed(&handler->_state) != JL_TASK_STATE_RUNNABLE) + continue; + handler->result = jl_interrupt_exception; + handler->_isexception = 1; + jl_schedule_task(handler); + } +} +static int want_interrupt_handlers(void) +{ + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (handlers && (jl_array_len(handlers) > 0)) { + // Set flag to trigger user handlers on next task switch + jl_atomic_store_relaxed(&handle_interrupt, 1); + return 1; + } + return 0; +} + #if defined(_WIN32) #include "signals-win.c" #else diff --git a/src/signals-unix.c b/src/signals-unix.c index 6ed664199fd2b4..266cb4c0b66237 100644 --- a/src/signals-unix.c +++ b/src/signals-unix.c @@ -527,11 +527,14 @@ void usr2_handler(int sig, siginfo_t *info, void *ctx) jl_atomic_exchange(&ptls->signal_request, 0); // returns -1 if (request == 2) { int force = jl_check_force_sigint(); + if (!force && want_interrupt_handlers()) { + return; + } if (force || (!ptls->defer_signal && ptls->io_wait)) { jl_safepoint_consume_sigint(); + // Force a throw if (force) jl_safe_printf("WARNING: Force throwing a SIGINT\n"); - // Force a throw jl_clear_force_sigint(); jl_throw_in_ctx(ct, jl_interrupt_exception, sig, ctx); } @@ -802,7 +805,7 @@ static void *signal_listener(void *arg) profile = (sig == SIGUSR1); #if defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L if (profile && !(info.si_code == SI_TIMER && - info.si_value.sival_ptr == &timerprof)) + info.si_value.sival_ptr == &timerprof)) profile = 0; #endif #elif defined(HAVE_ITIMER) @@ -817,6 +820,10 @@ static void *signal_listener(void *arg) else if (exit_on_sigint) { critical = 1; } + // FIXME: Skip this if force + else if (want_interrupt_handlers()) { + continue; + } else { jl_try_deliver_sigint(); continue; diff --git a/src/signals-win.c b/src/signals-win.c index f20a4d5287669f..c5720d84d19d2a 100644 --- a/src/signals-win.c +++ b/src/signals-win.c @@ -221,7 +221,8 @@ static BOOL WINAPI sigint_handler(DWORD wsig) //This needs winapi types to guara if (!jl_ignore_sigint()) { if (exit_on_sigint) jl_exit(128 + sig); // 128 + SIGINT - jl_try_deliver_sigint(); + if (!want_interrupt_handlers()) + jl_try_deliver_sigint(); } return 1; } diff --git a/src/task.c b/src/task.c index 7373de937b9aeb..ecb21ea2dc58fd 100644 --- a/src/task.c +++ b/src/task.c @@ -621,8 +621,12 @@ JL_NO_ASAN static void ctx_switch(jl_task_t *lastt) sanitizer_finish_switch_fiber(ptls->previous_task, jl_atomic_load_relaxed(&ptls->current_task)); } +JL_DLLIMPORT void jl_schedule_interrupt_handlers(void); + JL_DLLEXPORT void jl_switch(void) JL_NOTSAFEPOINT_LEAVE JL_NOTSAFEPOINT_ENTER { + jl_schedule_interrupt_handlers(); + jl_task_t *ct = jl_current_task; jl_ptls_t ptls = ct->ptls; jl_task_t *t = ptls->next_task; @@ -997,7 +1001,7 @@ JL_DLLEXPORT void jl_task_wait() jl_apply(&wait_func, 1); ct->world_age = last_age; } - +#endif JL_DLLEXPORT void jl_schedule_task(jl_task_t *task) { static jl_function_t *sched_func = NULL; @@ -1011,7 +1015,6 @@ JL_DLLEXPORT void jl_schedule_task(jl_task_t *task) jl_apply(args, 2); ct->world_age = last_age; } -#endif // Do one-time initializations for task system void jl_init_tasks(void) JL_GC_DISABLED diff --git a/src/threading.h b/src/threading.h index 4df6815124eb9c..c04df9c1c92401 100644 --- a/src/threading.h +++ b/src/threading.h @@ -27,6 +27,8 @@ jl_ptls_t jl_init_threadtls(int16_t tid) JL_NOTSAFEPOINT; void jl_init_threadinginfra(void); void jl_threadfun(void *arg); +extern jl_array_t *jl_interrupt_handlers; + #ifdef __cplusplus } #endif