Skip to content

Commit

Permalink
YJIT: Support **nil
Browse files Browse the repository at this point in the history
This adds YJIT support for VM_CALL_KW_SPLAT with nil, specifically for
when we already know from the context that it's done with a nil. This is
enough to support forwarding with `...` when there no keyword arguments
are present.

Amend the kw_rest support to propagate the type of the parameter to help
with this. Test interactions with splat, since the splat array sits
lower on the stack when a kw_splat argument is present.
  • Loading branch information
XrXr committed Feb 16, 2024
1 parent 1b9b960 commit 80aee1a
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 22 deletions.
27 changes: 27 additions & 0 deletions bootstraptest/test_yjit.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4549,3 +4549,30 @@ def callsite = rest(complete: false)
callsite
}

# splat+kw_splat+opt+rest
assert_equal '[1, []]', %q{
def opt_rest(a = 0, *rest) = [a, rest]
def call_site(args) = opt_rest(*args, **nil)
call_site([1])
}

# splat+kw_splat+opt+rest
assert_equal '[1, []]', %q{
def opt_rest(a = 0, *rest) = [a, rest]
def call_site(args) = opt_rest(*args, **nil)
call_site([1])
}

# splat and nil kw_splat
assert_equal 'ok', %q{
def identity(x) = x
def splat_nil_kw_splat(args) = identity(*args, **nil)
splat_nil_kw_splat([:ok])
}
11 changes: 11 additions & 0 deletions test/ruby/test_yjit.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,17 @@ def test_send_polymorphic_method_name
RUBY
end

def test_kw_splat_nil
assert_compiles(<<~'RUBY', result: %i[ok ok], no_send_fallbacks: true)
def id(x) = x
def kw_fw(arg, **) = id(arg, **)
def fw(...) = id(...)
def use = [fw(:ok), kw_fw(:ok)]
use
RUBY
end

private

def code_gc_helpers
Expand Down
65 changes: 44 additions & 21 deletions yjit/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6037,6 +6037,12 @@ fn gen_send_cfunc(
return None;
}

// Don't JIT calls with keyword splat
if flags & VM_CALL_KW_SPLAT != 0 {
gen_counter_incr(asm, Counter::send_kw_splat);
return None;
}

let kw_arg = unsafe { vm_ci_kwarg(ci) };
let kw_arg_num = if kw_arg.is_null() {
0
Expand Down Expand Up @@ -6515,6 +6521,7 @@ fn gen_send_iseq(
let iseq_has_rest = unsafe { get_iseq_flags_has_rest(iseq) };
let iseq_has_block_param = unsafe { get_iseq_flags_has_block(iseq) };
let arg_setup_block = captured_opnd.is_some(); // arg_setup_type: arg_setup_block (invokeblock)
let kw_splat = flags & VM_CALL_KW_SPLAT != 0;

// For computing offsets to callee locals
let num_params = unsafe { get_iseq_body_param_size(iseq) as i32 };
Expand All @@ -6533,7 +6540,7 @@ fn gen_send_iseq(
};

// Arity handling and optional parameter setup
let mut opts_filled = argc - required_num - kw_arg_num;
let mut opts_filled = argc - required_num - kw_arg_num - i32::from(kw_splat);
let opt_num = unsafe { get_iseq_body_param_opt_num(iseq) };
// With a rest parameter or a yield to a block,
// callers can pass more than required + optional.
Expand All @@ -6544,11 +6551,13 @@ fn gen_send_iseq(
let mut opts_missing: i32 = opt_num - opts_filled;

let block_arg = flags & VM_CALL_ARGS_BLOCKARG != 0;
// Stack index of the splat array
let splat_pos = i32::from(block_arg) + i32::from(kw_splat) + kw_arg_num;

exit_if_stack_too_large(iseq)?;
exit_if_tail_call(asm, ci)?;
exit_if_has_post(asm, iseq)?;
exit_if_kw_splat(asm, flags)?;
exit_if_kwsplat_non_nil(asm, flags)?;
exit_if_has_rest_and_captured(asm, iseq_has_rest, captured_opnd)?;
exit_if_has_kwrest_and_captured(asm, has_kwrest, captured_opnd)?;
exit_if_has_rest_and_supplying_kws(asm, iseq_has_rest, supplying_kws)?;
Expand Down Expand Up @@ -6586,7 +6595,7 @@ fn gen_send_iseq(
}

let splat_array_length = if flags & VM_CALL_ARGS_SPLAT != 0 {
let array = jit.peek_at_stack(&asm.ctx, if block_arg { 1 } else { 0 }) ;
let array = jit.peek_at_stack(&asm.ctx, splat_pos as isize);
let array_length = if array == Qnil {
0
} else if unsafe { !RB_TYPE_P(array, RUBY_T_ARRAY) } {
Expand All @@ -6599,7 +6608,7 @@ fn gen_send_iseq(
// Arity check accounting for size of the splat. When callee has rest parameters, we insert
// runtime guards later in copy_splat_args_for_rest_callee()
if !iseq_has_rest {
let supplying = argc - 1 + array_length as i32;
let supplying = argc - 1 - i32::from(kw_splat) + array_length as i32;
if (required_num..=required_num + opt_num).contains(&supplying) == false {
gen_counter_incr(asm, Counter::send_iseq_splat_arity_error);
return None;
Expand All @@ -6615,7 +6624,7 @@ fn gen_send_iseq(
// On a normal splat without rest and option args this is handled
// elsewhere depending on the case
asm_comment!(asm, "Side exit if length doesn't not equal compile time length");
let array_len_opnd = get_array_len(asm, asm.stack_opnd(if block_arg { 1 } else { 0 }));
let array_len_opnd = get_array_len(asm, asm.stack_opnd(splat_pos));
asm.cmp(array_len_opnd, array_length.into());
asm.jne(Target::side_exit(Counter::guard_send_splatarray_length_not_equal));
}
Expand Down Expand Up @@ -6657,7 +6666,7 @@ fn gen_send_iseq(
if let Some(len) = splat_array_length {
assert_eq!(kw_arg_num, 0); // Due to exit_if_doing_kw_and_splat().
// Simplifies calculation below.
let num_args = (argc - 1) + len as i32;
let num_args = argc - 1 - i32::from(kw_splat) + len as i32;

opts_filled = if num_args >= required_num {
min(num_args - required_num, opt_num)
Expand Down Expand Up @@ -6764,14 +6773,12 @@ fn gen_send_iseq(
asm.jbe(Target::side_exit(Counter::guard_send_se_cf_overflow));

if iseq_has_rest && flags & VM_CALL_ARGS_SPLAT != 0 {
let splat_pos = i32::from(block_arg) + kw_arg_num;

// Insert length guard for a call to copy_splat_args_for_rest_callee()
// that will come later. We will have made changes to
// the stack by spilling or handling __send__ shifting
// by the time we get to that code, so we need the
// guard here where we can still side exit.
let non_rest_arg_count = argc - 1;
let non_rest_arg_count = argc - i32::from(kw_splat) - 1;
if non_rest_arg_count < required_num + opt_num {
let take_count: u32 = (required_num - non_rest_arg_count + opts_filled)
.try_into().unwrap();
Expand Down Expand Up @@ -6839,6 +6846,13 @@ fn gen_send_iseq(
_ => unreachable!(),
}

if kw_splat {
// Only `**nil` is supported right now. Checked in exit_if_kwsplat_non_nil()
assert_eq!(Type::Nil, asm.ctx.get_opnd_type(StackOpnd(0)));
asm.stack_pop(1);
argc -= 1;
}

// push_splat_args does stack manipulation so we can no longer side exit
if let Some(array_length) = splat_array_length {
if !iseq_has_rest {
Expand Down Expand Up @@ -7329,7 +7343,7 @@ fn gen_iseq_kw_call(

// Build the keyword rest parameter hash before we make any changes to the order of
// the supplied keyword arguments
if has_kwrest {
let kwrest_type = if has_kwrest {
c_callable! {
fn build_kw_rest(rest_mask: u64, stack_kwargs: *const VALUE, keywords: *const rb_callinfo_kwarg) -> VALUE {
if keywords.is_null() {
Expand Down Expand Up @@ -7419,7 +7433,11 @@ fn gen_iseq_kw_call(
if stack_kwrest_idx >= 0 {
asm.ctx.set_opnd_mapping(stack_kwrest.into(), TempMapping::map_to_stack(kwrest_type));
}
}

Some(kwrest_type)
} else {
None
};

// Ensure the stack is large enough for the callee
for _ in caller_keyword_len..callee_kw_count {
Expand Down Expand Up @@ -7499,8 +7517,17 @@ fn gen_iseq_kw_call(
// explicitly given a value and have a non-constant default.
if callee_kw_count > 0 {
let unspec_opnd = VALUE::fixnum_from_usize(unspecified_bits).as_u64();
asm.ctx.dealloc_temp_reg(asm.stack_opnd(-1).stack_idx()); // avoid using a register for unspecified_bits
asm.mov(asm.stack_opnd(-1), unspec_opnd.into());
let top = asm.stack_push(Type::Fixnum);
asm.mov(top, unspec_opnd.into());
argc += 1;
}

// The kwrest parameter sits after `unspecified_bits`
if let Some(kwrest_type) = kwrest_type {
let kwrest = asm.stack_push(kwrest_type);
// We put the kwrest parameter in memory earlier
asm.ctx.dealloc_temp_reg(kwrest.stack_idx());
argc += 1;
}

argc
Expand Down Expand Up @@ -7531,8 +7558,10 @@ fn exit_if_has_post(asm: &mut Assembler, iseq: *const rb_iseq_t) -> Option<()> {
}

#[must_use]
fn exit_if_kw_splat(asm: &mut Assembler, flags: u32) -> Option<()> {
exit_if(asm, flags & VM_CALL_KW_SPLAT != 0, Counter::send_iseq_kw_splat)
fn exit_if_kwsplat_non_nil(asm: &mut Assembler, flags: u32) -> Option<()> {
let kw_splat = flags & VM_CALL_KW_SPLAT != 0;
let kw_splat_stack = StackOpnd((flags & VM_CALL_ARGS_BLOCKARG != 0).into());
exit_if(asm, kw_splat && asm.ctx.get_opnd_type(kw_splat_stack) != Type::Nil, Counter::send_iseq_kw_splat_non_nil)
}

#[must_use]
Expand Down Expand Up @@ -7829,12 +7858,6 @@ fn gen_send_general(
let mut mid = unsafe { vm_ci_mid(ci) };
let mut flags = unsafe { vm_ci_flag(ci) };

// Don't JIT calls with keyword splat
if flags & VM_CALL_KW_SPLAT != 0 {
gen_counter_incr(asm, Counter::send_kw_splat);
return None;
}

// Defer compilation so we can specialize on class of receiver
if !jit.at_current_insn() {
defer_compilation(jit, asm, ocb);
Expand Down
2 changes: 1 addition & 1 deletion yjit/src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ make_counters! {
send_iseq_complex_discard_extras,
send_iseq_leaf_builtin_block_arg_block_param,
send_iseq_only_keywords,
send_iseq_kw_splat,
send_iseq_kw_splat_non_nil,
send_iseq_kwargs_req_and_opt_missing,
send_iseq_kwargs_mismatch,
send_iseq_has_post,
Expand Down

0 comments on commit 80aee1a

Please sign in to comment.