Skip to content

Commit

Permalink
YJIT: Fuse opt_eq, opt_lt, and branchunless for fixnums
Browse files Browse the repository at this point in the history
This targets the instruction pair that usually results from code like
`if a == b` or `if a < b`.

Fusing the two makes it so we don't need to use csel and then test the
result:

```diff
-# Insn: 0007 opt_eq
+# Insn: 0007 opt_eq (stack_size: 3)
 cmp x9, x10
-mov x11, #0x14
-mov x12, #0
-csel x11, x11, x12, eq
-mov x9, x11
-# Insn: 0009 branchunless
-tst x9, #-5
-b.eq #0x1092ee174
+b.ne #0x10457a174
 nop
-b #0x1092ee1a8
+# gen_direct_jmp: fallthrough
```

By using gen_direct_jump() when the branch is not taken, the fusion
also eliminates one stub for each fused branch, saving some code size.
  • Loading branch information
XrXr committed Apr 9, 2024
1 parent 5891878 commit ca761b5
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 2 deletions.
1 change: 1 addition & 0 deletions yjit.rb
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def _print_stats(out: $stderr) # :nodoc:

out.puts "branch_insn_count: " + format_number(13, stats[:branch_insn_count])
out.puts "branch_known_count: " + format_number_pct(13, stats[:branch_known_count], stats[:branch_insn_count])
out.puts "branch_fused_count: " + format_number_pct(13, stats[:branch_fused_count], stats[:branch_insn_count])

out.puts "freed_iseq_count: " + format_number(13, stats[:freed_iseq_count])
out.puts "invalidation_count: " + format_number(13, stats[:invalidation_count])
Expand Down
97 changes: 95 additions & 2 deletions yjit/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,28 @@ impl JITState {
unsafe { *(self.pc.offset(arg_idx + 1)) }
}

// Read operand number `arg_idx` of the instruction that immediately follows
// the one currently being compiled. Panics if `arg_idx` is not a valid operand
// index for that following instruction.
pub fn next_insn_arg(&self, arg_idx: usize) -> VALUE {
    // SAFETY: assumes `self.pc` points at the current instruction within the
    // iseq's encoded body and that another instruction follows it
    // (NOTE(review): not checked here -- confirm callers guarantee this).
    let next_pc = unsafe { self.pc.add(insn_len(self.get_opcode()).as_usize()) };
    let next_opcode = unsafe { rb_iseq_opcode_at_pc(self.iseq, next_pc) };

    // The instruction length counts the opcode word itself, so operand
    // `arg_idx` lives at slot `arg_idx + 1` and must fall within the length.
    assert!(insn_len(next_opcode as usize) as usize > (arg_idx + 1));

    // SAFETY: the assert above keeps the read inside the following instruction.
    unsafe { next_pc.add(arg_idx + 1).read() }
}

// Index of the instruction that follows the current one: the current
// instruction index advanced by the current instruction's length.
fn next_insn_idx(&self) -> u16 {
    let current_len = insn_len(self.get_opcode()) as u16;
    self.insn_idx + current_len
}

// Opcode of the instruction that follows the current one, looked up by its
// index in the iseq.
pub fn next_insn_opcode(&self) -> ruby_vminsn_type {
    iseq_opcode_at_idx(self.iseq, self.next_insn_idx().into())
}

// Check if we are compiling the instruction at the stub PC
// Meaning we are compiling the instruction that is next to execute
pub fn at_current_insn(&self) -> bool {
Expand Down Expand Up @@ -1392,8 +1409,7 @@ fn fuse_putobject_opt_ltlt(
constant_object: VALUE,
ocb: &mut OutlinedCb,
) -> Option<CodegenStatus> {
let next_opcode = unsafe { rb_vm_insn_addr2opcode(jit.pc.add(insn_len(jit.opcode).as_usize()).read().as_ptr()) };
if next_opcode == YARVINSN_opt_ltlt as i32 && constant_object.fixnum_p() {
if jit.next_insn_opcode() == YARVINSN_opt_ltlt && constant_object.fixnum_p() {
// Untag the fixnum shift amount
let shift_amt = constant_object.as_isize() >> 1;
if shift_amt > 63 || shift_amt < 0 {
Expand Down Expand Up @@ -3372,6 +3388,9 @@ fn gen_opt_lt(
asm: &mut Assembler,
ocb: &mut OutlinedCb,
) -> Option<CodegenStatus> {
if let Some(result) = fuse_fixnum_cmp_branchunless(jit, asm, ocb, BranchGenFn::JGEToTarget0) {
return Some(result);
}
gen_fixnum_cmp(jit, asm, ocb, Assembler::csel_l, BOP_LT)
}

Expand Down Expand Up @@ -3515,11 +3534,85 @@ fn gen_equality_specialized(
}
}

/// Attempt to generate code for a compare and branch instruction pair for fixnums, usually coming
/// from code like `if a == b`. The caller specifies the condition for when the `branchunless`
/// should be taken through `branch_fn`.
fn fuse_fixnum_cmp_branchunless(
jit: &mut JITState,
asm: &mut Assembler,
ocb: &mut OutlinedCb,
branch_fn: BranchGenFn,
) -> Option<CodegenStatus> {
if jit.next_insn_opcode() != YARVINSN_branchunless {
return None;
}

// Only fuse when there are two fixnums on the stack
match asm.ctx.two_fixnums_on_stack(jit) {
Some(true) => (),
Some(false) => return None,
None => {
// Defer compilation so we can peek at the stack
defer_compilation(jit, asm, ocb);
return Some(EndBlock);
}
};

// Give up if integer equality is overridden
if !assume_bop_not_redefined(jit, asm, ocb, INTEGER_REDEFINED_OP_FLAG, BOP_EQ) {
return None;
}

// Check that both operands are fixnums
guard_two_fixnums(jit, asm, ocb);

let jump_offset = jit.next_insn_arg(0).as_i32();

// Check for interrupts, but only on backward branches that may create loops
if jump_offset < 0 {
gen_check_ints(asm, Counter::branchunless_interrupted);
}

// Get the branch target instruction offsets
let next_idx = jit.next_insn_idx() + insn_len(YARVINSN_branchunless.as_usize()) as u16;
let jump_idx = (next_idx as i32) + jump_offset;
let next_block = BlockId {
iseq: jit.iseq,
idx: next_idx,
};
let jump_block = BlockId {
iseq: jit.iseq,
idx: jump_idx.try_into().unwrap(),
};

// Get the operands from the stack
let arg1 = asm.stack_pop(1);
let arg0 = asm.stack_pop(1);

incr_counter!(branch_insn_count);
incr_counter!(branch_fused_count);

// Compare the arguments
asm.cmp(arg0, arg1);

// Generate the branch instructions
let mut ctx = asm.ctx;
ctx.reset_chain_depth_and_defer();
gen_branch(jit, asm, ocb, jump_block, &ctx, None, None, branch_fn);

gen_direct_jump(jit, &ctx, next_block, asm);
Some(EndBlock)
}

fn gen_opt_eq(
jit: &mut JITState,
asm: &mut Assembler,
ocb: &mut OutlinedCb,
) -> Option<CodegenStatus> {
if let Some(result) = fuse_fixnum_cmp_branchunless(jit, asm, ocb, BranchGenFn::JNZToTarget0) {
return Some(result);
}

let specialized = match gen_equality_specialized(jit, asm, ocb, true) {
Some(specialized) => specialized,
None => {
Expand Down
4 changes: 4 additions & 0 deletions yjit/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ pub enum BranchGenFn {
JBEToTarget0,
JBToTarget0,
JOMulToTarget0,
JGEToTarget0,
JITReturn,
}

Expand Down Expand Up @@ -573,6 +574,7 @@ impl BranchGenFn {
asm.jmp(target0);
}
}
BranchGenFn::JGEToTarget0 => asm.jge(target0),
BranchGenFn::JNZToTarget0 => {
asm.jnz(target0)
}
Expand Down Expand Up @@ -608,6 +610,7 @@ impl BranchGenFn {
BranchGenFn::JBEToTarget0 |
BranchGenFn::JBToTarget0 |
BranchGenFn::JOMulToTarget0 |
BranchGenFn::JGEToTarget0 |
BranchGenFn::JITReturn => BranchShape::Default,
}
}
Expand All @@ -630,6 +633,7 @@ impl BranchGenFn {
BranchGenFn::JBEToTarget0 |
BranchGenFn::JBToTarget0 |
BranchGenFn::JOMulToTarget0 |
BranchGenFn::JGEToTarget0 |
BranchGenFn::JITReturn => {
assert_eq!(new_shape, BranchShape::Default);
}
Expand Down
1 change: 1 addition & 0 deletions yjit/src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ make_counters! {
defer_empty_count,
branch_insn_count,
branch_known_count,
branch_fused_count,
max_inline_versions,

freed_iseq_count,
Expand Down

0 comments on commit ca761b5

Please sign in to comment.