From 4ed6f41b7d519694909a1582069da1ac025a249d Mon Sep 17 00:00:00 2001 From: Alan Wu Date: Fri, 23 Feb 2024 23:04:07 -0500 Subject: [PATCH] YJIT: Support splat calls to C methods with -1 arity Usually we deal with splats by speculating that they're of a specific size. In this case, the C method takes a pointer and a length, so we can support changing sizes just fine. --- bootstraptest/test_yjit.rb | 24 +++++++++++++ yjit.c | 45 ++++++++++++++++++++++++ yjit/bindgen/src/main.rs | 2 ++ yjit/src/codegen.rs | 63 ++++++++++++++++++++++++++-------- yjit/src/cruby_bindings.inc.rs | 9 +++++ yjit/src/stats.rs | 3 +- 6 files changed, 131 insertions(+), 15 deletions(-) diff --git a/bootstraptest/test_yjit.rb b/bootstraptest/test_yjit.rb index a2c98c01c939aa..1d98b92d3a425a 100644 --- a/bootstraptest/test_yjit.rb +++ b/bootstraptest/test_yjit.rb @@ -4634,3 +4634,27 @@ def splat_nil_kw_splat(args) = identity(*args, **nil) empty = [] [1.abs(*empty), 1.abs(**nil), 1.bit_length(*empty, **nil)] } + +# splat into C methods with -1 arity +assert_equal '[[1, 2, 3], [0, 2, 3], [1, 2, 3], [2, 2, 3], [], [], [{}]]', %q{ + class Foo < Array + def push(args) = super(1, *args) + end + + def test_cfunc_vargs_splat(sub_instance, array_class, empty_kw_hash) + splat = [2, 3] + kw_splat = [empty_kw_hash] + [ + sub_instance.push(splat), + array_class[0, *splat, **nil], + array_class[1, *splat, &nil], + array_class[2, *splat, **nil, &nil], + array_class.send(:[], *kw_splat), + # kw_splat disables keywords hash handling + array_class[*kw_splat], + array_class[*kw_splat, **nil], + ] + end + + test_cfunc_vargs_splat(Foo.new, Array, Hash.ruby2_keywords_hash({})) +} diff --git a/yjit.c b/yjit.c index 207974073b0f68..8a52ca8024b352 100644 --- a/yjit.c +++ b/yjit.c @@ -890,6 +890,51 @@ rb_yjit_ruby2_keywords_splat_p(VALUE obj) return FL_TEST_RAW(last, RHASH_PASS_AS_KEYWORDS); } +// Checks to establish preconditions for rb_yjit_splat_varg_cfunc() +VALUE +rb_yjit_splat_varg_checks(VALUE *sp, VALUE splat_array, rb_control_frame_t *cfp) +{ + if (!RB_TYPE_P(splat_array, T_ARRAY)) return Qfalse; + long len = RARRAY_LEN(splat_array); + + // Large splat arrays need a separate allocation + if (len < 0 || len > VM_ARGC_STACK_MAX) return Qfalse; + + // Would we overflow if we put the contents of the array onto the stack? + if (sp + len > (VALUE *)(cfp - 2 * sizeof(*cfp))) return Qfalse; + + return Qtrue; +} + +// Push array elements to the stack for a C method that has a variable number +// of parameters. Returns the number of arguments the splat array contributes. +int +rb_yjit_splat_varg_cfunc(VALUE *stack_splat_array, bool sole_splat) +{ + VALUE splat_array = *stack_splat_array; + int len; + + // We already checked that length fits in `int` + RUBY_ASSERT(RB_TYPE_P(splat_array, T_ARRAY)); + len = (int)RARRAY_LEN(splat_array); + + // If this is a splat call without any keyword arguments, exclude the + // ruby2_keywords hash if it's empty + if (sole_splat && len > 0) { + VALUE last_hash = RARRAY_AREF(splat_array, len - 1); + if (RB_TYPE_P(last_hash, T_HASH) && + FL_TEST_RAW(last_hash, RHASH_PASS_AS_KEYWORDS) && + RHASH_EMPTY_P(last_hash)) { + len--; + } + } + + // Push the contents of the array onto the stack + MEMCPY(stack_splat_array, RARRAY_CONST_PTR(splat_array), VALUE, len); + + return len; +} + // Print the Ruby source location of some ISEQ for debugging purposes void rb_yjit_dump_iseq_loc(const rb_iseq_t *iseq, uint32_t insn_idx) diff --git a/yjit/bindgen/src/main.rs b/yjit/bindgen/src/main.rs index a1b8cf3a75dd0e..c58df7c3773b90 100644 --- a/yjit/bindgen/src/main.rs +++ b/yjit/bindgen/src/main.rs @@ -460,6 +460,8 @@ fn main() { .allowlist_function("rb_vm_base_ptr") .allowlist_function("rb_ec_stack_check") .allowlist_function("rb_vm_top_self") + .allowlist_function("rb_yjit_splat_varg_checks") + .allowlist_function("rb_yjit_splat_varg_cfunc") // We define VALUE manually, don't import it .blocklist_type("VALUE") diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs index 41c7adae2c658d..e1317fd6b7b6ff 100644 --- a/yjit/src/codegen.rs +++ b/yjit/src/codegen.rs @@ -6140,18 +6140,16 @@ fn gen_send_cfunc( let cfunc_argc = unsafe { get_mct_argc(cfunc) }; let mut argc = argc; + // Splat call to a C method that takes `VALUE *` and `len` + let variable_splat = flags & VM_CALL_ARGS_SPLAT != 0 && cfunc_argc == -1; + let block_arg = flags & VM_CALL_ARGS_BLOCKARG != 0; + // If the function expects a Ruby array of arguments if cfunc_argc < 0 && cfunc_argc != -1 { gen_counter_incr(asm, Counter::send_cfunc_ruby_array_varg); return None; } - // We aren't handling a vararg cfuncs with splat currently. - if flags & VM_CALL_ARGS_SPLAT != 0 && cfunc_argc == -1 { - gen_counter_incr(asm, Counter::send_args_splat_cfunc_var_args); - return None; - } - exit_if_kwsplat_non_nil(asm, flags, Counter::send_cfunc_kw_splat_non_nil)?; let kw_splat = flags & VM_CALL_KW_SPLAT != 0; @@ -6217,6 +6215,16 @@ fn gen_send_cfunc( asm.cmp(CFP, stack_limit); asm.jbe(Target::side_exit(Counter::guard_send_se_cf_overflow)); + // Guard for variable length splat call before any modifications to the stack + if variable_splat { + asm_comment!(asm, "guard variable length splat call servicable"); + let sp = asm.ctx.sp_opnd(0); + let splat_array = asm.stack_opnd(i32::from(kw_splat) + i32::from(block_arg)); + let proceed = asm.ccall(rb_yjit_splat_varg_checks as _, vec![sp, splat_array, CFP]); + asm.cmp(proceed, Qfalse.into()); + asm.je(Target::side_exit(Counter::guard_send_cfunc_bad_splat_vargs)); + } + // Number of args which will be passed through to the callee // This is adjusted by the kwargs being combined into a hash. let mut passed_argc = if kw_arg.is_null() { @@ -6242,7 +6250,6 @@ fn gen_send_cfunc( return None; } - let block_arg = flags & VM_CALL_ARGS_BLOCKARG != 0; let block_arg_type = if block_arg { Some(asm.ctx.get_opnd_type(StackOpnd(0))) } else { @@ -6287,9 +6294,9 @@ fn gen_send_cfunc( argc -= 1; } - // push_splat_args does stack manipulation so we can no longer side exit - if flags & VM_CALL_ARGS_SPLAT != 0 { - assert!(cfunc_argc >= 0); + // Splat handling when C method takes a static number of arguments. + // push_splat_args() does stack manipulation so we can no longer side exit + if flags & VM_CALL_ARGS_SPLAT != 0 && cfunc_argc >= 0 { let required_args : u32 = (cfunc_argc as u32).saturating_sub(argc as u32 - 1); // + 1 because we pass self if required_args + 1 >= C_ARG_OPNDS.len() as u32 { @@ -6312,15 +6319,34 @@ fn gen_send_cfunc( handle_opt_send_shift_stack(asm, argc); } + // Push a dynamic number of items from the splat array to the stack when calling a vargs method + let dynamic_splat_size = if variable_splat { + asm_comment!(asm, "variable length splat"); + let just_splat = usize::from(!kw_splat && kw_arg.is_null()).into(); + let stack_splat_array = asm.lea(asm.stack_opnd(0)); + Some(asm.ccall(rb_yjit_splat_varg_cfunc as _, vec![stack_splat_array, just_splat])) + } else { + None + }; + // Points to the receiver operand on the stack let recv = asm.stack_opnd(argc); // Store incremented PC into current control frame in case callee raises. jit_save_pc(jit, asm); - // Increment the stack pointer by 3 (in the callee) - // sp += 3 - let sp = asm.lea(asm.ctx.sp_opnd(SIZEOF_VALUE_I32 * 3)); + // Find callee's SP with space for metadata. + // Usually sp+3. + let sp = if let Some(splat_size) = dynamic_splat_size { + // Compute the callee's SP at runtime in case we accept a variable size for the splat array + const _: () = assert!(SIZEOF_VALUE == 8, "opting for a shift since mul on A64 takes no immediates"); + let splat_size_bytes = asm.lshift(splat_size, 3usize.into()); + // 3 items for method metadata, minus one to remove the splat array + let static_stack_top = asm.lea(asm.ctx.sp_opnd(SIZEOF_VALUE_I32 * 2)); + asm.add(static_stack_top, splat_size_bytes) + } else { + asm.lea(asm.ctx.sp_opnd(SIZEOF_VALUE_I32 * 3)) + }; let specval = if block_arg_type == Some(Type::BlockParamProxy) { SpecVal::BlockHandler(Some(BlockHandler::BlockParamProxy)) @@ -6382,8 +6408,17 @@ fn gen_send_cfunc( else if cfunc_argc == -1 { // The method gets a pointer to the first argument // rb_f_puts(int argc, VALUE *argv, VALUE recv) + + let passed_argc_opnd = if let Some(splat_size) = dynamic_splat_size { + // The final argc is the size of the splat, minus one for the splat array itself + asm.add(splat_size, (passed_argc - 1).into()) + } else { + // Without a splat, passed_argc is static + Opnd::Imm(passed_argc.into()) + }; + vec![ - Opnd::Imm(passed_argc.into()), + passed_argc_opnd, asm.lea(asm.ctx.sp_opnd(-argc * SIZEOF_VALUE_I32)), asm.stack_opnd(argc), ] diff --git a/yjit/src/cruby_bindings.inc.rs b/yjit/src/cruby_bindings.inc.rs index 3ba351cea9dbd7..b131b62bfd1b1c 100644 --- a/yjit/src/cruby_bindings.inc.rs +++ b/yjit/src/cruby_bindings.inc.rs @@ -1194,6 +1194,15 @@ extern "C" { pub fn rb_yjit_fix_div_fix(recv: VALUE, obj: VALUE) -> VALUE; pub fn rb_yjit_fix_mod_fix(recv: VALUE, obj: VALUE) -> VALUE; pub fn rb_yjit_ruby2_keywords_splat_p(obj: VALUE) -> usize; + pub fn rb_yjit_splat_varg_checks( + sp: *mut VALUE, + splat_array: VALUE, + cfp: *mut rb_control_frame_t, + ) -> VALUE; + pub fn rb_yjit_splat_varg_cfunc( + stack_splat_array: *mut VALUE, + sole_splat: bool, + ) -> ::std::os::raw::c_int; pub fn rb_yjit_dump_iseq_loc(iseq: *const rb_iseq_t, insn_idx: u32); pub fn rb_yjit_iseq_inspect(iseq: *const rb_iseq_t) -> *mut ::std::os::raw::c_char; pub fn rb_FL_TEST(obj: VALUE, flags: VALUE) -> VALUE; diff --git a/yjit/src/stats.rs b/yjit/src/stats.rs index 6cdf1d0616a2e4..d60e35b9dd22d6 100644 --- a/yjit/src/stats.rs +++ b/yjit/src/stats.rs @@ -386,7 +386,6 @@ make_counters! { send_args_splat_aref, send_args_splat_aset, send_args_splat_opt_call, - send_args_splat_cfunc_var_args, send_iseq_splat_arity_error, send_splat_too_long, send_send_wrong_args, @@ -444,6 +443,8 @@ make_counters! { guard_send_not_string, guard_send_respond_to_mid_mismatch, + guard_send_cfunc_bad_splat_vargs, + guard_invokesuper_me_changed, guard_invokeblock_tag_changed,