From 4faa5997d31fd82c1b9736e932080d718fd11424 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aurimas=20Bla=C5=BEulionis?= <0x60@pm.me> Date: Fri, 27 Oct 2023 14:28:39 +0100 Subject: [PATCH] Rework packet waker atomic ops to be less error prone to implement Add test-asan workflow --- .github/workflows/build.yml | 32 +++++++ mfio/src/io/mod.rs | 16 ++-- mfio/src/io/packet/mod.rs | 175 ++++++++++++++++++++++++------------ mfio/src/io/packet/view.rs | 4 +- 4 files changed, 161 insertions(+), 66 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e7668dd..b75e3d2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -115,6 +115,38 @@ jobs: - name: Run all tests run: cargo test --workspace --all-features --verbose + test-asan: + runs-on: ${{ matrix.os }} + env: + RUSTFLAGS: -Zsanitizer=address -C debuginfo=2 ${{ matrix.rustflags }} + RUSTDOCFLAGS: -Zsanitizer=address -C debuginfo=2 ${{ matrix.rustflags }} + CARGO_BUILD_RUSTFLAGS: -C debuginfo=2 + ASAN_OPTIONS: symbolize=1 detect_leaks=0 + timeout-minutes: 20 + strategy: + fail-fast: false + matrix: + # TODO: enable windows, macos + os: [ubuntu-latest] + toolchain: ["nightly-2023-09-01"] + rustflags: ["--cfg mfio_assume_linear_types --cfg tokio_unstable", "--cfg tokio_unstable"] + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + toolchain: ${{ matrix.toolchain }} + override: true + + - name: Get rustc target + run: | + echo "RUSTC_TARGET=$(rustc -vV | sed -n 's|host: ||p')" >> $GITHUB_OUTPUT + id: target + - name: Install llvm + run: sudo apt update && sudo apt install llvm-13 + - run: rustup component add rust-src + - name: Run all tests + run: cargo -Zbuild-std test --verbose --target ${{ steps.target.outputs.RUSTC_TARGET }} + lint: runs-on: ${{ matrix.os }} env: diff --git a/mfio/src/io/mod.rs b/mfio/src/io/mod.rs index f62b5a2..b6f3be8 100644 --- a/mfio/src/io/mod.rs +++ b/mfio/src/io/mod.rs @@ -102,7 +102,7 @@ pub trait PacketIoExt: PacketIo { //IoFut::NewId(self, param, packet.stack()) IoFut { pkt: UnsafeCell::new(Some(packet.stack())), - initial_state: Some((self, param)), + initial_state: UnsafeCell::new(Some((self, param))), _phantom: PhantomData, } } @@ -116,7 +116,7 @@ pub trait PacketIoExt: PacketIo { //IoFut::NewId(self, param, packet.stack()) IoToFut { pkt_out: UnsafeCell::new(Some((packet.stack(), output.stack()))), - initial_state: Some((self, param)), + initial_state: UnsafeCell::new(Some((self, param))), _phantom: PhantomData, } } @@ -177,7 +177,7 @@ impl NoPos { pub struct IoFut<'a, T, Perms: PacketPerms, Param, Packet: PacketStore<'a, Perms>> { pkt: UnsafeCell>>, - initial_state: Option<(&'a T, Param)>, + initial_state: UnsafeCell>, _phantom: PhantomData, } @@ -187,10 +187,10 @@ impl<'a, T: PacketIo, Perms: PacketPerms, Param, Pkt: PacketStore< type Output = Pkt::StackReq<'a>; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let state = unsafe { self.get_unchecked_mut() }; + let state: &Self = unsafe { core::mem::transmute(self) }; loop { - match state.initial_state.take() { + match unsafe { (*state.initial_state.get()).take() } { Some((io, param)) => { // SAFETY: this packet's existence is tied to 'a lifetime, meaning it will be valid // throughout 'a. @@ -230,7 +230,7 @@ pub struct IoToFut< Output: OutputStore<'a, Perms>, > { pkt_out: UnsafeCell, Output::StackReq<'a>)>>, - initial_state: Option<(&'a T, Param)>, + initial_state: UnsafeCell>, _phantom: PhantomData, } @@ -244,9 +244,9 @@ impl< > IoToFut<'a, T, Perms, Param, Pkt, Out> { pub fn submit(self: Pin<&mut Self>) -> &Out::StackReq<'a> { - let state = unsafe { self.get_unchecked_mut() }; + let state: &Self = unsafe { core::mem::transmute(self) }; - if let Some((io, param)) = state.initial_state.take() { + if let Some((io, param)) = unsafe { (*state.initial_state.get()).take() } { // SAFETY: this packet's existence is tied to 'a lifetime, meaning it will be valid // throughout 'a. let (pkt, out): &'a mut (Pkt::StackReq<'a>, Out::StackReq<'a>) = diff --git a/mfio/src/io/packet/mod.rs b/mfio/src/io/packet/mod.rs index 6987f7c..e11d7c2 100644 --- a/mfio/src/io/packet/mod.rs +++ b/mfio/src/io/packet/mod.rs @@ -20,6 +20,102 @@ pub use output::*; mod view; pub use view::*; +const LOCK_BIT: u64 = 1 << 63; +const HAS_WAKER_BIT: u64 = 1 << 62; +const FINALIZED_BIT: u64 = 1 << 61; +const ALL_BITS: u64 = LOCK_BIT | HAS_WAKER_BIT | FINALIZED_BIT; + +struct RcAndWaker { + rc_and_flags: AtomicU64, + waker: UnsafeCell>, +} + +impl Default for RcAndWaker { + fn default() -> Self { + Self { + rc_and_flags: 0.into(), + waker: UnsafeCell::new(MaybeUninit::uninit()), + } + } +} + +impl core::fmt::Debug for RcAndWaker { + fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result { + write!( + fmt, + "{}", + (self.rc_and_flags.load(Ordering::Relaxed) & HAS_WAKER_BIT) != 0 + ) + } +} + +impl RcAndWaker { + fn acquire(&self) -> bool { + (loop { + let flags = self.rc_and_flags.fetch_or(LOCK_BIT, Ordering::AcqRel); + if (flags & LOCK_BIT) == 0 { + break flags; + } + while self.rc_and_flags.load(Ordering::Relaxed) & LOCK_BIT != 0 { + core::hint::spin_loop(); + } + } & HAS_WAKER_BIT) + != 0 + } + + pub fn take(&self) -> Option { + let ret = if self.acquire() { + Some(unsafe { (*self.waker.get()).assume_init_read() }) + } else { + None + }; + self.rc_and_flags + .fetch_and(!(LOCK_BIT | HAS_WAKER_BIT), Ordering::Release); + ret + } + + pub fn write(&self, waker: CWaker) -> u64 { + if self.acquire() { + unsafe { core::ptr::drop_in_place((*self.waker.get()).as_mut_ptr()) } + } + + unsafe { *self.waker.get() = MaybeUninit::new(waker) }; + + self.rc_and_flags.fetch_or(HAS_WAKER_BIT, Ordering::Relaxed); + self.rc_and_flags.fetch_and(!LOCK_BIT, Ordering::AcqRel) & !ALL_BITS + } + + pub fn acquire_rc(&self) -> u64 { + self.rc_and_flags.load(Ordering::Acquire) & !ALL_BITS + } + + pub fn dec_rc(&self) -> (u64, bool) { + let ret = self.rc_and_flags.fetch_sub(1, Ordering::AcqRel); + (ret & !ALL_BITS, (ret & HAS_WAKER_BIT) != 0) + } + + pub fn inc_rc(&self) -> u64 { + self.rc_and_flags.fetch_add(1, Ordering::AcqRel) & !ALL_BITS + } + + pub fn finalize(&self) { + self.rc_and_flags.fetch_or(FINALIZED_BIT, Ordering::Release); + } + + pub fn wait_finalize(&self) { + // FIXME: in theory, wait_finalize should only wait for the FINALIZED_BIT, but not deal + // with the locking and the waker. However, something is making us have to take the waker, + // to make these atomic ops sound (however, even then I doubt this is fully sound, but is + // merely moving probability of desync lower). + // Either way, we should be able to have this waker mechanism be way more optimized, + // without atomic locks. + self.take(); + while (self.rc_and_flags.load(Ordering::Acquire) & FINALIZED_BIT) == 0 { + core::hint::spin_loop(); + } + } +} + /// Describes a full packet. /// /// This packet is considered simple. @@ -548,9 +644,7 @@ pub struct Packet { /// /// return true /// ``` - rc_and_flags: AtomicUsize, - /// Waker to be triggered, upon `rc` dropping down to 0. - waker: UnsafeCell>, + rc_and_waker: RcAndWaker, /// What was the smallest position that resulted in an error. /// /// This value is initialized to !0, and upon each errored packet segment, is minned @@ -573,17 +667,8 @@ unsafe impl Sync for Packet {} impl Drop for Packet { fn drop(&mut self) { - let loaded = self.rc_and_flags.load(Ordering::Acquire); - assert_eq!( - loaded & !(0b11 << 62), - 0, - "The packet has in-flight segments." - ); - if loaded >> 62 == 0b11 { - unsafe { - core::ptr::drop_in_place(self.waker.get_mut().as_mut_ptr()); - } - } + let loaded = self.rc_and_waker.acquire_rc(); + assert_eq!(loaded, 0, "The packet has in-flight segments."); } } @@ -593,38 +678,18 @@ impl<'a, Perms: PacketPerms> Future for &'a Packet { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let this = Pin::into_inner(self); - // Clear the flag bits, because we want the end writing bit be properly set - let flags = this.rc_and_flags.fetch_and(!(0b11 << 62), Ordering::AcqRel); + let rc = this.rc_and_waker.write(cx.waker().clone().into()); - // Drop the old waker - if (flags >> 62) == 0b11 { - unsafe { - core::ptr::drop_in_place((*this.waker.get()).as_mut_ptr()); - } - } + if rc == 0 { + // Synchronize the thread that last decremented the refcount. + // If we don't, we risk a race condition where we drop the packet, while the packet + // reference is still being used to take the waker. + this.rc_and_waker.wait_finalize(); - // Load in the start writing bit - let loaded = this.rc_and_flags.fetch_or(0b1 << 63, Ordering::AcqRel); - - if loaded & !(0b11 << 62) == 0 { // no more packets left, we don't need to write anything return Poll::Ready(()); } - unsafe { - *this.waker.get() = MaybeUninit::new(cx.waker().clone().into()); - } - - // Load in the end writing bit. - let loaded = this.rc_and_flags.fetch_or(0b1 << 62, Ordering::AcqRel); - - if loaded & !(0b11 << 62) == 0 { - // no more packets left, we wrote uselessly - // The waker will be freed in packet drop... - return Poll::Ready(()); - } - - // true indicates the waker was installed and we can go to sleep. Poll::Pending } } @@ -632,7 +697,7 @@ impl<'a, Perms: PacketPerms> Future for &'a Packet { impl Packet { /// Current reference count of the packet. pub fn rc(&self) -> usize { - self.rc_and_flags.load(Ordering::Relaxed) & !(0b11 << 62) + (self.rc_and_waker.acquire_rc()) as usize } unsafe fn on_output(&self, error: Option<(u64, NonZeroI32)>) -> Option { @@ -642,29 +707,28 @@ impl Packet { } } - let loaded = self.rc_and_flags.fetch_sub(1, Ordering::AcqRel); + let (prev, has_waker) = self.rc_and_waker.dec_rc(); - // Do nothing, because we are either: - // - // - Not the last packet (any of the first 62 bits set). - // - The waker was not fully written yet (the last 2 bits are not 0b11). This case will be - // handled by the polling thread appropriately. - if loaded != (0b11 << 62) + 1 { + // Do nothing, because we are not the last packet (any of the first 62 bits set). + if prev != 1 { return None; } - if self.rc_and_flags.fetch_and(!(0b11 << 62), Ordering::AcqRel) >> 62 == 0b11 { - // FIXME: dial this atomic codepath in, because we've seen uninitialized reads. - Some(core::ptr::read(self.waker.get()).assume_init()) + let ret = if has_waker { + self.rc_and_waker.take() } else { None - } + }; + + self.rc_and_waker.finalize(); + + ret } unsafe fn on_add_to_view(&self) { - let rc = self.rc_and_flags.fetch_add(1, Ordering::AcqRel) & !(0b11 << 62); + let rc = self.rc_and_waker.inc_rc(); if rc != 0 { - self.rc_and_flags.fetch_sub(1, Ordering::AcqRel); + self.rc_and_waker.dec_rc(); assert_eq!(rc, 0); } } @@ -688,8 +752,7 @@ impl Packet { pub unsafe fn new_hdr(vtbl: PacketVtblRef) -> Self { Packet { vtbl, - rc_and_flags: AtomicUsize::new(0), - waker: UnsafeCell::new(MaybeUninit::uninit()), + rc_and_waker: Default::default(), error_clamp: (!0u64).into(), min_error: 0.into(), } diff --git a/mfio/src/io/packet/view.rs b/mfio/src/io/packet/view.rs index 11b94e3..238bad0 100644 --- a/mfio/src/io/packet/view.rs +++ b/mfio/src/io/packet/view.rs @@ -381,7 +381,7 @@ impl<'a, Perms: PacketPerms> PacketView<'a, Perms> { assert!(pos < self.len()); // TODO: maybe relaxed is enough here? - self.pkt().rc_and_flags.fetch_add(1, Ordering::Release); + self.pkt().rc_and_waker.inc_rc(); let Self { pkt, @@ -425,7 +425,7 @@ impl<'a, Perms: PacketPerms> PacketView<'a, Perms> { /// /// Please see [`BoundPacketView::extract_packet`] documentation for details. pub unsafe fn extract_packet(&self, offset: u64, len: u64) -> Self { - self.pkt().rc_and_flags.fetch_add(1, Ordering::AcqRel); + self.pkt().rc_and_waker.inc_rc(); let Self { pkt, tag, start, ..