Skip to content

Commit

Permalink
YJIT: Avoid writing return value to memory in leave
Browse files Browse the repository at this point in the history
Previously, at the end of `leave` we did
`*caller_cfp->sp = return_value`, like the interpreter.
With future changes that leaves the SP field uninitialized for C frames,
this will become problematic. For cases like returning from
`rb_funcall()`, the return value was written above the stack never
read anyway (the copy in the return register is what callers use).

Leave the return value in a register at the end of `leave` and have the
code at `cfp->jit_return` decide what to do with it. For JIT-to-JIT
return, it goes through `asm.stack_push()` and benefits from register
allocation for stack temporaries.

Mostly flat on benchmarks, with maybe some marginal speed improvements.
  • Loading branch information
XrXr committed Oct 3, 2023
1 parent d47af93 commit edbef15
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 41 deletions.
6 changes: 6 additions & 0 deletions yjit/src/backend/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub const SP: Opnd = _SP;

pub const C_ARG_OPNDS: [Opnd; 6] = _C_ARG_OPNDS;
pub const C_RET_OPND: Opnd = _C_RET_OPND;
pub use crate::backend::current::{Reg, C_RET_REG};

// Memory operand base
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
Expand Down Expand Up @@ -955,6 +956,7 @@ pub struct SideExitContext {
pub stack_size: u8,
pub sp_offset: i8,
pub reg_temps: RegTemps,
pub is_return_landing: bool,
}

impl SideExitContext {
Expand All @@ -965,6 +967,7 @@ impl SideExitContext {
stack_size: ctx.get_stack_size(),
sp_offset: ctx.get_sp_offset(),
reg_temps: ctx.get_reg_temps(),
is_return_landing: ctx.is_return_landing(),
};
if cfg!(debug_assertions) {
// Assert that we're not losing any mandatory metadata
Expand All @@ -979,6 +982,9 @@ impl SideExitContext {
ctx.set_stack_size(self.stack_size);
ctx.set_sp_offset(self.sp_offset);
ctx.set_reg_temps(self.reg_temps);
if self.is_return_landing {
ctx.set_as_return_landing();
}
ctx
}
}
Expand Down
54 changes: 37 additions & 17 deletions yjit/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,12 @@ fn gen_exit(exit_pc: *mut VALUE, asm: &mut Assembler) {
asm_comment!(asm, "exit to interpreter on {}", insn_name(opcode as usize));
}

if asm.ctx.is_return_landing() {
asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));
let top = asm.stack_push(Type::Unknown);
asm.mov(top, C_RET_OPND);
}

// Spill stack temps before returning to the interpreter
asm.spill_temps();

Expand Down Expand Up @@ -636,13 +642,18 @@ fn gen_leave_exception(ocb: &mut OutlinedCb) -> CodePtr {
let code_ptr = ocb.get_write_ptr();
let mut asm = Assembler::new();

// gen_leave() leaves the return value in C_RET_OPND before coming here.
let ruby_ret_val = asm.live_reg_opnd(C_RET_OPND);

// Every exit to the interpreter should be counted
gen_counter_incr(&mut asm, Counter::leave_interp_return);

asm_comment!(asm, "increment SP of the caller");
let sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP);
asm_comment!(asm, "push return value through cfp->sp");
let cfp_sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP);
let sp = asm.load(cfp_sp);
asm.mov(Opnd::mem(64, sp, 0), ruby_ret_val);
let new_sp = asm.add(sp, SIZEOF_VALUE.into());
asm.mov(sp, new_sp);
asm.mov(cfp_sp, new_sp);

asm_comment!(asm, "exit from exception");
asm.cpop_into(SP);
Expand Down Expand Up @@ -872,6 +883,18 @@ pub fn gen_single_block(
asm_comment!(asm, "reg_temps: {:08b}", asm.ctx.get_reg_temps().as_u8());
}

if asm.ctx.is_return_landing() {
// Continuation of the end of gen_leave().
// Reload REG_SP for the current frame and transfer the return value
// to the stack top.
asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));

let top = asm.stack_push(Type::Unknown);
asm.mov(top, C_RET_OPND);

asm.ctx.clear_return_landing();
}

// For each instruction to compile
// NOTE: could rewrite this loop with a std::iter::Iterator
while insn_idx < iseq_size {
Expand Down Expand Up @@ -6535,17 +6558,14 @@ fn gen_send_iseq(
// The callee might change locals through Kernel#binding and other means.
asm.ctx.clear_local_types();

// Pop arguments and receiver in return context, push the return value
// After the return, sp_offset will be 1. The codegen for leave writes
// the return value in case of JIT-to-JIT return.
// Pop arguments and receiver in return context and
// mark it as a continuation of gen_leave()
let mut return_asm = Assembler::new();
return_asm.ctx = asm.ctx.clone();
return_asm.stack_pop(sp_offset.try_into().unwrap());
let return_val = return_asm.stack_push(Type::Unknown);
// The callee writes a return value on stack. Update reg_temps accordingly.
return_asm.ctx.dealloc_temp_reg(return_val.stack_idx());
return_asm.ctx.set_sp_offset(1);
return_asm.ctx.set_sp_offset(0); // We set SP on the caller's frame above
return_asm.ctx.reset_chain_depth();
return_asm.ctx.set_as_return_landing();

// Write the JIT return address on the callee frame
gen_branch(
Expand Down Expand Up @@ -7745,15 +7765,15 @@ fn gen_leave(
// Load the return value
let retval_opnd = asm.stack_pop(1);

// Move the return value into the C return register for gen_leave_exit()
// Move the return value into the C return register
asm.mov(C_RET_OPND, retval_opnd);

// Reload REG_SP for the caller and write the return value.
// Top of the stack is REG_SP[0] since the caller has sp_offset=1.
asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));
asm.mov(Opnd::mem(64, SP, 0), C_RET_OPND);

// Jump to the JIT return address on the frame that was just popped
// Jump to the JIT return address on the frame that was just popped.
// There are a few possible jump targets:
// - gen_leave_exit() and gen_leave_exception(), for C callers
// - Return context set up by gen_send_iseq()
// We don't write the return value to stack memory like the interpreter here.
// Each jump target do it as necessary.
let offset_to_jit_return =
-(RUBY_SIZEOF_CONTROL_FRAME as i32) + RUBY_OFFSET_CFP_JIT_RETURN;
asm.jmp_opnd(Opnd::mem(64, CFP, offset_to_jit_return));
Expand Down
94 changes: 70 additions & 24 deletions yjit/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,11 @@ pub struct Context {
/// Bitmap of which stack temps are in a register
reg_temps: RegTemps,

// Depth of this block in the sidechain (eg: inline-cache chain)
chain_depth: u8,
/// Fields packed into u8
/// - Lower 7 bits: Depth of this block in the sidechain (eg: inline-cache chain)
/// - Top bit: Whether this code is the target of a JIT-to-JIT Ruby return
/// ([Self::is_return_landing])
chain_depth_plus: u8,

// Local variable types we keep track of
local_types: [Type; MAX_LOCAL_TYPES],
Expand Down Expand Up @@ -1399,7 +1402,7 @@ fn find_block_version(blockid: BlockId, ctx: &Context) -> Option<BlockRef> {
/// Produce a generic context when the block version limit is hit for a blockid
pub fn limit_block_versions(blockid: BlockId, ctx: &Context) -> Context {
// Guard chains implement limits separately, do nothing
if ctx.chain_depth > 0 {
if ctx.get_chain_depth() > 0 {
return ctx.clone();
}

Expand Down Expand Up @@ -1607,6 +1610,9 @@ impl Context {
generic_ctx.stack_size = self.stack_size;
generic_ctx.sp_offset = self.sp_offset;
generic_ctx.reg_temps = self.reg_temps;
if self.is_return_landing() {
generic_ctx.set_as_return_landing();
}
generic_ctx
}

Expand Down Expand Up @@ -1637,15 +1643,30 @@ impl Context {
}

pub fn get_chain_depth(&self) -> u8 {
self.chain_depth
self.chain_depth_plus & 0x7f
}

pub fn reset_chain_depth(&mut self) {
self.chain_depth = 0;
self.chain_depth_plus &= 0x80;
}

pub fn increment_chain_depth(&mut self) {
self.chain_depth += 1;
if self.get_chain_depth() == 0x7f {
panic!("max block version chain depth reached!");
}
self.chain_depth_plus += 1;
}

pub fn set_as_return_landing(&mut self) {
self.chain_depth_plus |= 0x80;
}

pub fn clear_return_landing(&mut self) {
self.chain_depth_plus &= 0x7f;
}

pub fn is_return_landing(&self) -> bool {
self.chain_depth_plus & 0x80 > 0
}

/// Get an operand for the adjusted stack pointer address
Expand Down Expand Up @@ -1842,13 +1863,17 @@ impl Context {
let src = self;

// Can only lookup the first version in the chain
if dst.chain_depth != 0 {
if dst.get_chain_depth() != 0 {
return TypeDiff::Incompatible;
}

// Blocks with depth > 0 always produce new versions
// Sidechains cannot overlap
if src.chain_depth != 0 {
if src.get_chain_depth() != 0 {
return TypeDiff::Incompatible;
}

if src.is_return_landing() != dst.is_return_landing() {
return TypeDiff::Incompatible;
}

Expand Down Expand Up @@ -2495,6 +2520,9 @@ fn branch_stub_hit_body(branch_ptr: *const c_void, target_idx: u32, ec: EcPtr) -
let running_iseq = rb_cfp_get_iseq(cfp);
let reconned_pc = rb_iseq_pc_at_idx(running_iseq, target_blockid.idx.into());
let reconned_sp = original_interp_sp.offset(target_ctx.sp_offset.into());
// Unlike in the interpreter, our `leave` doesn't write to the caller's
// SP -- we do it in the returned-to code. Account for this difference.
let reconned_sp = reconned_sp.add(target_ctx.is_return_landing().into());

assert_eq!(running_iseq, target_blockid.iseq as _, "each stub expects a particular iseq");

Expand Down Expand Up @@ -2631,10 +2659,16 @@ fn gen_branch_stub(
asm.set_reg_temps(ctx.reg_temps);
asm_comment!(asm, "branch stub hit");

if asm.ctx.is_return_landing() {
asm.mov(SP, Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP));
let top = asm.stack_push(Type::Unknown);
asm.mov(top, C_RET_OPND);
}

// Save caller-saved registers before C_ARG_OPNDS get clobbered.
// Spill all registers for consistency with the trampoline.
for &reg in caller_saved_temp_regs().iter() {
asm.cpush(reg);
for &reg in caller_saved_temp_regs() {
asm.cpush(Opnd::Reg(reg));
}

// Spill temps to the VM stack as well for jit.peek_at_stack()
Expand Down Expand Up @@ -2675,36 +2709,51 @@ pub fn gen_branch_stub_hit_trampoline(ocb: &mut OutlinedCb) -> CodePtr {
// Since this trampoline is static, it allows code GC inside
// branch_stub_hit() to free stubs without problems.
asm_comment!(asm, "branch_stub_hit() trampoline");
let jump_addr = asm.ccall(
let stub_hit_ret = asm.ccall(
branch_stub_hit as *mut u8,
vec![
C_ARG_OPNDS[0],
C_ARG_OPNDS[1],
EC,
]
);
let jump_addr = asm.load(stub_hit_ret);

// Restore caller-saved registers for stack temps
for &reg in caller_saved_temp_regs().iter().rev() {
asm.cpop_into(reg);
for &reg in caller_saved_temp_regs().rev() {
asm.cpop_into(Opnd::Reg(reg));
}

// Jump to the address returned by the branch_stub_hit() call
asm.jmp_opnd(jump_addr);

// HACK: restoring the value of C_RET_REG clobbers the
// return value of branch_stub_hit we want to jump to,
// so we need a scratch regsiter. This unrechable `test`
// maintain liveness of C return reg so we get something
// else for the scratch register.
asm.test(stub_hit_ret, 0.into());

asm.compile(ocb, None);

code_ptr
}

/// Return registers to be pushed and popped on branch_stub_hit.
/// The return value may include an extra register for x86 alignment.
fn caller_saved_temp_regs() -> Vec<Opnd> {
let mut regs = Assembler::get_temp_regs().to_vec();
if regs.len() % 2 == 1 {
regs.push(*regs.last().unwrap()); // x86 alignment
fn caller_saved_temp_regs() -> impl Iterator<Item = &'static Reg> + DoubleEndedIterator {
let temp_regs = Assembler::get_temp_regs().iter();
let len = temp_regs.len();
// The return value gen_leave() leaves in C_RET_REG
// needs to survive the branch_stub_hit() call.
let regs = temp_regs.chain(std::iter::once(&C_RET_REG));

// On x86_64, maintain 16-byte stack alignment
if cfg!(target_arch = "x86_64") && len % 2 == 0 {
static ONE_MORE: [Reg; 1] = [C_RET_REG];
regs.chain(ONE_MORE.iter())
} else {
regs.chain(&[])
}
regs.iter().map(|&reg| Opnd::Reg(reg)).collect()
}

impl Assembler
Expand Down Expand Up @@ -2831,16 +2880,13 @@ pub fn defer_compilation(
asm: &mut Assembler,
ocb: &mut OutlinedCb,
) {
if asm.ctx.chain_depth != 0 {
if asm.ctx.get_chain_depth() != 0 {
panic!("Double defer!");
}

let mut next_ctx = asm.ctx.clone();

if next_ctx.chain_depth == u8::MAX {
panic!("max block version chain depth reached!");
}
next_ctx.chain_depth += 1;
next_ctx.increment_chain_depth();

let branch = new_pending_branch(jit, BranchGenFn::JumpToTarget0(Cell::new(BranchShape::Default)));

Expand Down

0 comments on commit edbef15

Please sign in to comment.