diff --git a/base/task.jl b/base/task.jl index ae39a20164436e..c51ef6b7e441cd 100644 --- a/base/task.jl +++ b/base/task.jl @@ -993,3 +993,32 @@ if Sys.iswindows() else pause() = ccall(:pause, Cvoid, ()) end + +""" + register_interrupt_handler(handler::Task) + +Registers the task `handler` to handle interrupts (such as from Ctrl-C). When +an interrupt is received, all registered handler tasks will be scheduled at the +next `yield` call, with an `InterruptException` thrown to them. Once any +handler is registered, the runtime will only throw `InterruptException`s to +handlers, and not to any other task, allowing the handlers to soak up and +safely handle interrupts. + +To unregister a previously-registered handler, use +[`unregister_interrupt_handler`](@ref). + +!!! warn + Note that non-yielding tasks may block interrupt handlers from running; + this means that once an interrupt handler is registered, code like `while + true end` may become un-interruptible. +""" +register_interrupt_handler(handler::Task) = + ccall(:jl_register_interrupt_handler, Cvoid, (Any,), handler) +""" + unregister_interrupt_handler(handler::Task) + +Unregisters the interrupt handler task `handler`; see +[`register_interrupt_handler`](@ref) for further details. +""" +unregister_interrupt_handler(handler::Task) = + ccall(:jl_unregister_interrupt_handler, Cvoid, (Any,), handler) diff --git a/src/gc.c b/src/gc.c index 60b110826ee805..397f267ac4046e 100644 --- a/src/gc.c +++ b/src/gc.c @@ -2953,6 +2953,7 @@ static void gc_mark_roots(jl_gc_markqueue_t *mq) gc_try_claim_and_push(mq, jl_emptytuple_type, NULL); gc_try_claim_and_push(mq, cmpswap_names, NULL); gc_try_claim_and_push(mq, jl_global_roots_table, NULL); + gc_try_claim_and_push(mq, jl_interrupt_handlers, NULL); } // find unmarked objects that need to be finalized from the finalizer list "list". diff --git a/src/init.c b/src/init.c index 36e83fdd9c24d4..aef010812e6353 100644 --- a/src/init.c +++ b/src/init.c @@ -708,6 +708,7 @@ extern jl_mutex_t jl_modules_mutex; extern jl_mutex_t precomp_statement_out_lock; extern jl_mutex_t newly_inferred_mutex; extern jl_mutex_t global_roots_lock; +extern jl_mutex_t interrupt_handlers_lock; static void restore_fp_env(void) { @@ -727,6 +728,7 @@ static void init_global_mutexes(void) { JL_MUTEX_INIT(&global_roots_lock, "global_roots_lock"); JL_MUTEX_INIT(&jl_codegen_lock, "jl_codegen_lock"); JL_MUTEX_INIT(&typecache_lock, "typecache_lock"); + JL_MUTEX_INIT(&interrupt_handlers_lock, "interrupt_handlers_lock"); } JL_DLLEXPORT void julia_init(JL_IMAGE_SEARCH rel) diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index 092a48be819307..33c3a89b97bd30 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -54,6 +54,7 @@ XX(jl_int8_type) \ XX(jl_interconditional_type) \ XX(jl_interrupt_exception) \ + XX(jl_interrupt_handlers) \ XX(jl_intrinsic_type) \ XX(jl_kwcall_func) \ XX(jl_lineinfonode_type) \ diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index c09f2aff4cb887..b022b6a75517cd 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -220,6 +220,7 @@ XX(jl_get_field) \ XX(jl_get_global) \ XX(jl_get_image_file) \ + XX(jl_get_interrupt_handlers) \ XX(jl_get_JIT) \ XX(jl_get_julia_bin) \ XX(jl_get_julia_bindir) \ @@ -399,6 +400,7 @@ XX(jl_match_cache_flags) \ XX(jl_read_verify_header) \ XX(jl_realloc) \ + XX(jl_register_interrupt_handler) \ XX(jl_register_newmeth_tracer) \ XX(jl_reshape_array) \ XX(jl_resolve_globals_in_ir) \ @@ -511,6 +513,7 @@ XX(jl_uncompress_argname_n) \ XX(jl_uncompress_ir) \ XX(jl_undefined_var_error) \ + XX(jl_unregister_interrupt_handler) \ XX(jl_value_ptr) \ XX(jl_ver_is_release) \ XX(jl_ver_major) \ diff --git a/src/julia_threads.h b/src/julia_threads.h index 07c722253c7f5b..4ddcf1053b51fb 100644 --- a/src/julia_threads.h +++ b/src/julia_threads.h @@ -372,6 +372,8 @@ JL_DLLEXPORT int8_t jl_gc_is_in_finalizer(void); JL_DLLEXPORT void jl_wakeup_thread(int16_t tid); +JL_DLLEXPORT void jl_schedule_task(struct _jl_task_t *task); + #ifdef __cplusplus } #endif diff --git a/src/signal-handling.c b/src/signal-handling.c index e241fd22ecb186..c70aefa07633c3 100644 --- a/src/signal-handling.c +++ b/src/signal-handling.c @@ -304,6 +304,76 @@ static void jl_check_profile_autostop(void) } } +JL_DLLEXPORT _Atomic(jl_array_t *) jl_interrupt_handlers JL_GLOBALLY_ROOTED = NULL; +JL_DLLEXPORT jl_array_t *jl_get_interrupt_handlers(void) +{ + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (!handlers) { + static jl_datatype_t *jl_array_task_type; + if (!jl_array_task_type) + jl_array_task_type = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_task_type, jl_box_long(1)); + jl_array_t *new_handlers = jl_alloc_array_1d((jl_value_t *)jl_array_task_type, 0); + if (jl_atomic_cmpswap(&jl_interrupt_handlers, &handlers, new_handlers)) { + handlers = new_handlers; + } else { + handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + } + } + assert(handlers); + return handlers; +} +jl_mutex_t interrupt_handlers_lock; +static _Atomic(int) handle_interrupt = 0; +JL_DLLEXPORT void jl_schedule_interrupt_handlers(void) +{ + if (jl_atomic_exchange_relaxed(&handle_interrupt, 0) != 1) + return; + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (!handlers) + return; + for (int i = 0; i < jl_array_len(handlers); i++) { + jl_task_t *handler = ((jl_task_t **)jl_array_data(handlers))[i]; + assert(jl_is_task(handler)); + if (handler->ptls) + continue; + if (jl_atomic_load_relaxed(&handler->_state) != JL_TASK_STATE_RUNNABLE) + continue; + handler->result = jl_interrupt_exception; + handler->_isexception = 1; + jl_schedule_task(handler); + } +} +static int want_interrupt_handlers(void) +{ + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (handlers && (jl_array_len(handlers) > 0)) { + // Set flag to trigger user handlers on next task switch + jl_atomic_store_relaxed(&handle_interrupt, 1); + return 1; + } + return 0; +} +JL_DLLEXPORT void jl_register_interrupt_handler(jl_task_t *handler) +{ + JL_LOCK(&interrupt_handlers_lock); + jl_array_t *handlers = jl_get_interrupt_handlers(); + jl_array_grow_end(handlers, 1); + jl_arrayset(handlers, (jl_value_t *)handler, jl_array_len(handlers)-1); + JL_UNLOCK(&interrupt_handlers_lock); +} +extern void jl_array_del_at(jl_array_t *a, ssize_t idx, size_t dec); +JL_DLLEXPORT void jl_unregister_interrupt_handler(jl_task_t *handler) +{ + JL_LOCK(&interrupt_handlers_lock); + jl_array_t *handlers = jl_get_interrupt_handlers(); + for (int i = jl_array_len(handlers)-1; i >= 0; i--) { + jl_task_t *this_handler = ((jl_task_t **)jl_array_data(handlers))[i]; + if (handler == this_handler) + jl_array_del_at(handlers, i, 1); + } + JL_UNLOCK(&interrupt_handlers_lock); +} + #if defined(_WIN32) #include "signals-win.c" #else diff --git a/src/signals-unix.c b/src/signals-unix.c index 4c21d25d3622c3..4570408233218d 100644 --- a/src/signals-unix.c +++ b/src/signals-unix.c @@ -527,9 +527,9 @@ void usr2_handler(int sig, siginfo_t *info, void *ctx) int force = jl_check_force_sigint(); if (force || (!ptls->defer_signal && ptls->io_wait)) { jl_safepoint_consume_sigint(); + // Force a throw if (force) jl_safe_printf("WARNING: Force throwing a SIGINT\n"); - // Force a throw jl_clear_force_sigint(); jl_throw_in_ctx(ct, jl_interrupt_exception, sig, ctx); } @@ -767,7 +767,7 @@ static void *signal_listener(void *arg) profile = (sig == SIGUSR1); #if defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L if (profile && !(info.si_code == SI_TIMER && - info.si_value.sival_ptr == &timerprof)) + info.si_value.sival_ptr == &timerprof)) profile = 0; #endif #endif @@ -780,6 +780,10 @@ static void *signal_listener(void *arg) else if (exit_on_sigint) { critical = 1; } + // FIXME: Skip this if force + else if (want_interrupt_handlers()) { + continue; + } else { jl_try_deliver_sigint(); continue; diff --git a/src/signals-win.c b/src/signals-win.c index 5dd6b34558ca6d..5429c62f8f9aaa 100644 --- a/src/signals-win.c +++ b/src/signals-win.c @@ -221,7 +221,8 @@ static BOOL WINAPI sigint_handler(DWORD wsig) //This needs winapi types to guara if (!jl_ignore_sigint()) { if (exit_on_sigint) jl_exit(128 + sig); // 128 + SIGINT - jl_try_deliver_sigint(); + if (!want_interrupt_handlers()) + jl_try_deliver_sigint(); } return 1; } diff --git a/src/task.c b/src/task.c index 9678cf2f3fe4ef..f8c723e3fd489b 100644 --- a/src/task.c +++ b/src/task.c @@ -621,8 +621,12 @@ JL_NO_ASAN static void ctx_switch(jl_task_t *lastt) sanitizer_finish_switch_fiber(ptls->previous_task, jl_atomic_load_relaxed(&ptls->current_task)); } +JL_DLLIMPORT void jl_schedule_interrupt_handlers(void); + JL_DLLEXPORT void jl_switch(void) JL_NOTSAFEPOINT_LEAVE JL_NOTSAFEPOINT_ENTER { + jl_schedule_interrupt_handlers(); + jl_task_t *ct = jl_current_task; jl_ptls_t ptls = ct->ptls; jl_task_t *t = ptls->next_task; @@ -1141,7 +1145,7 @@ JL_DLLEXPORT void jl_task_wait() jl_apply(&wait_func, 1); ct->world_age = last_age; } - +#endif JL_DLLEXPORT void jl_schedule_task(jl_task_t *task) { static jl_function_t *sched_func = NULL; @@ -1155,7 +1159,6 @@ JL_DLLEXPORT void jl_schedule_task(jl_task_t *task) jl_apply(args, 2); ct->world_age = last_age; } -#endif // Do one-time initializations for task system void jl_init_tasks(void) JL_GC_DISABLED diff --git a/src/threading.h b/src/threading.h index 40792a2889e44e..bcdc1a36b99817 100644 --- a/src/threading.h +++ b/src/threading.h @@ -28,6 +28,8 @@ void jl_init_threadinginfra(void); void jl_gc_threadfun(void *arg); void jl_threadfun(void *arg); +extern _Atomic(jl_array_t *) jl_interrupt_handlers JL_GLOBALLY_ROOTED; + #ifdef __cplusplus } #endif diff --git a/test/stress.jl b/test/stress.jl index b9fb720f0596ae..38b661064c3379 100644 --- a/test/stress.jl +++ b/test/stress.jl @@ -84,5 +84,34 @@ if !Sys.iswindows() ccall(:jl_gc_safepoint, Cvoid, ()) # wait for SIGINT to arrive end end + + # interrupt handlers + let exc_ref = Ref{Any}() + handler = Threads.@spawn begin + try + wait() + catch exc + exc_ref[] = exc + end + end + yield() # let the handler start + Base.register_interrupt_handler(handler) + ccall(:kill, Cvoid, (Cint, Cint,), getpid(), 2) + for i in 1:10 + Libc.systemsleep(0.1) + yield() # wait for the handler to be run + end + Base.unregister_interrupt_handler(handler) + @test isassigned(exc_ref) && exc_ref[] isa InterruptException + end + + # ensure we revert to original interrupt behavior + @test_throws InterruptException begin + ccall(:kill, Cvoid, (Cint, Cint,), getpid(), 2) + for i in 1:10 + Libc.systemsleep(0.1) + ccall(:jl_gc_safepoint, Cvoid, ()) # wait for SIGINT to arrive + end + end Base.exit_on_sigint(true) end