Skip to content

Commit

Permalink
Merge pull request #998 from microsoft/bugfix-inetstack-shared-waker
Browse files: browse the repository at this point in the history
[inetstack] Remove spurious wake ups
  • Loading branch information
iyzhang authored Nov 10, 2023
2 parents de9615f + a885972 commit ef7e1f1
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 46 deletions.
8 changes: 4 additions & 4 deletions src/rust/inetstack/protocols/arp/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ impl<const N: usize> SharedArpPeer<N> {

/// Background task that cleans up the ARP cache from time to time.
async fn background(&mut self) {
let yielder: Yielder = Yielder::new();
loop {
let yielder: Yielder = Yielder::new();
match self.runtime.get_timer().wait(Self::ARP_CLEANUP_TIMEOUT, yielder).await {
match self.runtime.get_timer().wait(Self::ARP_CLEANUP_TIMEOUT, &yielder).await {
Ok(()) => continue,
Err(_) => break,
}
Expand Down Expand Up @@ -276,13 +276,13 @@ impl<const N: usize> SharedArpPeer<N> {
// > The frequency of the ARP request is very close to one per
// > second, the maximum suggested by [RFC1122].
let result = {
let yielder: Yielder = Yielder::new();
for i in 0..self.arp_config.get_retry_count() + 1 {
self.transport.transmit(Box::new(msg.clone()));
let yielder: Yielder = Yielder::new();
let timer = self
.runtime
.get_timer()
.wait(self.arp_config.get_request_timeout(), yielder);
.wait(self.arp_config.get_request_timeout(), &yielder);

match arp_response.with_timeout(timer).await {
Ok(link_addr) => {
Expand Down
2 changes: 1 addition & 1 deletion src/rust/inetstack/protocols/icmpv4/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ impl<const N: usize> SharedIcmpv4Peer<N> {
};
let yielder: Yielder = Yielder::new();
let clock_ref: SharedTimer = self.runtime.get_timer();
let timer = clock_ref.wait(timeout, yielder);
let timer = clock_ref.wait(timeout, &yielder);
match rx.fuse().with_timeout(timer).await? {
// Request completed successfully.
Ok(_) => Ok(self.runtime.get_now() - t0),
Expand Down
13 changes: 9 additions & 4 deletions src/rust/inetstack/protocols/tcp/active_open.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use crate::{
scheduler::{
TaskHandle,
Yielder,
YielderHandle,
},
};
use ::libc::{
Expand Down Expand Up @@ -78,6 +79,7 @@ pub struct ActiveOpenSocket<const N: usize> {
arp: SharedArpPeer<N>,
result: AsyncValue<Result<SharedControlBlock<N>, Fail>>,
handle: Option<TaskHandle>,
yielder_handle: YielderHandle,
}

#[derive(Clone)]
Expand All @@ -98,6 +100,8 @@ impl<const N: usize> SharedActiveOpenSocket<N> {
local_link_addr: MacAddress,
arp: SharedArpPeer<N>,
) -> Result<Self, Fail> {
let yielder: Yielder = Yielder::new();
let yielder_handle: YielderHandle = yielder.get_handle();
let mut me: Self = Self(SharedObject::<ActiveOpenSocket<N>>::new(ActiveOpenSocket::<N> {
local_isn,
local,
Expand All @@ -109,11 +113,12 @@ impl<const N: usize> SharedActiveOpenSocket<N> {
arp,
result: AsyncValue::<Result<SharedControlBlock<N>, Fail>>::default(),
handle: None,
yielder_handle,
}));

let handle: TaskHandle = runtime.insert_background_coroutine(
"Inetstack::TCP::activeopen::background",
Box::pin(me.clone().background()),
Box::pin(me.clone().background(yielder)),
)?;
me.handle = Some(handle);
// TODO: Add fast path here when remote is already in the ARP cache (and subtract one retry).
Expand Down Expand Up @@ -226,13 +231,14 @@ impl<const N: usize> SharedActiveOpenSocket<N> {
None,
);
self.result.set(Ok(cb));
self.yielder_handle.wake_with(Ok(()));
let handle: TaskHandle = self.handle.take().expect("We should have allocated a background task");
if let Err(e) = self.runtime.remove_background_coroutine(&handle) {
panic!("Failed to remove active open coroutine (error={:?}", e);
}
}

async fn background(mut self) {
async fn background(mut self, yielder: Yielder) {
let handshake_retries: usize = self.tcp_config.get_handshake_retries();
let handshake_timeout = self.tcp_config.get_handshake_timeout();
for _ in 0..handshake_retries {
Expand Down Expand Up @@ -266,8 +272,7 @@ impl<const N: usize> SharedActiveOpenSocket<N> {
};
self.transport.transmit(Box::new(segment));
let clock_ref: SharedTimer = self.runtime.get_timer();
let yielder: Yielder = Yielder::new();
if let Err(e) = clock_ref.wait(handshake_timeout, yielder).await {
if let Err(e) = clock_ref.wait(handshake_timeout, &yielder).await {
self.result.set(Err(e));
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use ::futures::future::{
};
use ::std::time::Instant;

pub async fn acknowledger<const N: usize>(mut cb: SharedControlBlock<N>) -> Result<!, Fail> {
pub async fn acknowledger<const N: usize>(mut cb: SharedControlBlock<N>, yielder: Yielder) -> Result<!, Fail> {
loop {
// TODO: Implement TCP delayed ACKs, subject to restrictions from RFC 1122
// - TCP should implement a delayed ACK
Expand All @@ -33,16 +33,20 @@ pub async fn acknowledger<const N: usize>(mut cb: SharedControlBlock<N>) -> Resu
futures::pin_mut!(ack_deadline_changed);

let clock_ref: SharedTimer = cb.get_timer();
let yielder: Yielder = Yielder::new();
let ack_future = match deadline {
Some(t) => Either::Left(clock_ref.wait_until(t, yielder).fuse()),
Some(t) => Either::Left(clock_ref.wait_until(t, &yielder).fuse()),
None => Either::Right(future::pending()),
};
futures::pin_mut!(ack_future);

futures::select_biased! {
_ = ack_deadline_changed => continue,
_ = ack_future => {
match cb.get_ack_deadline().get() {
Some(timeout) if timeout > cb.get_now() => continue,
None => continue,
_ => {},
}
cb.send_ack();
},
}
Expand Down
18 changes: 9 additions & 9 deletions src/rust/inetstack/protocols/tcp/established/background/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,32 @@ use self::{
use crate::{
inetstack::protocols::tcp::established::ctrlblk::SharedControlBlock,
runtime::QDesc,
scheduler::Yielder,
};
use ::futures::{
channel::mpsc,
FutureExt,
};

pub async fn background<const N: usize>(
cb: SharedControlBlock<N>,
fd: QDesc,
_dead_socket_tx: mpsc::UnboundedSender<QDesc>,
) {
let acknowledger = acknowledger(cb.clone()).fuse();
pub async fn background<const N: usize>(cb: SharedControlBlock<N>, _dead_socket_tx: mpsc::UnboundedSender<QDesc>) {
let yielder_acknowledger: Yielder = Yielder::new();
let acknowledger = acknowledger(cb.clone(), yielder_acknowledger).fuse();
futures::pin_mut!(acknowledger);

let retransmitter = retransmitter(cb.clone()).fuse();
let yielder_retransmitter: Yielder = Yielder::new();
let retransmitter = retransmitter(cb.clone(), yielder_retransmitter).fuse();
futures::pin_mut!(retransmitter);

let sender = sender(cb.clone()).fuse();
let yielder_sender: Yielder = Yielder::new();
let sender = sender(cb.clone(), yielder_sender).fuse();
futures::pin_mut!(sender);

let r = futures::select_biased! {
r = acknowledger => r,
r = retransmitter => r,
r = sender => r,
};
error!("Connection (fd {:?}) terminated: {:?}", fd, r);
error!("Connection terminated: {:?}", r);

// TODO Properly clean up Peer state for this connection.
// dead_socket_tx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use ::std::time::{
Instant,
};

pub async fn retransmitter<const N: usize>(mut cb: SharedControlBlock<N>) -> Result<!, Fail> {
pub async fn retransmitter<const N: usize>(mut cb: SharedControlBlock<N>, yielder: Yielder) -> Result<!, Fail> {
loop {
// Pin future for timeout retransmission.
let mut rtx_deadline_watched: SharedWatchedValue<Option<Instant>> = cb.watch_retransmit_deadline();
Expand All @@ -29,9 +29,8 @@ pub async fn retransmitter<const N: usize>(mut cb: SharedControlBlock<N>) -> Res
let rtx_deadline_changed = rtx_deadline_watched.watch(rtx_yielder).fuse();
futures::pin_mut!(rtx_deadline_changed);
let clock_ref: SharedTimer = cb.get_timer();
let yielder: Yielder = Yielder::new();
let rtx_future = match rtx_deadline {
Some(t) => Either::Left(clock_ref.wait_until(t, yielder).fuse()),
Some(t) => Either::Left(clock_ref.wait_until(t, &yielder).fuse()),
None => Either::Right(future::pending()),
};
futures::pin_mut!(rtx_future);
Expand All @@ -52,11 +51,17 @@ pub async fn retransmitter<const N: usize>(mut cb: SharedControlBlock<N>) -> Res
}
futures::pin_mut!(rtx_fast_retransmit_changed);

// Since these futures all share a single waker bit, they are all woken whenever one of them triggers.
futures::select_biased! {
_ = rtx_deadline_changed => continue,
_ = rtx_fast_retransmit_changed => continue,
_ = rtx_future => {
trace!("Retransmission Timer Expired");
match cb.get_retransmit_deadline() {
Some(timeout) if timeout > cb.get_now() => continue,
None => continue,
_ => {},
}

// Notify congestion control about RTO.
// TODO: Is this the best place for this?
// TODO: Why call into ControlBlock to get SND.UNA when congestion_control_on_rto() has access to it?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use ::std::{
time::Duration,
};

pub async fn sender<const N: usize>(mut cb: SharedControlBlock<N>) -> Result<!, Fail> {
pub async fn sender<const N: usize>(mut cb: SharedControlBlock<N>, yielder: Yielder) -> Result<!, Fail> {
'top: loop {
// First, check to see if there's any unsent data.
// TODO: Change this to just look at the unsent queue to see if it is empty or not.
Expand Down Expand Up @@ -85,12 +85,11 @@ pub async fn sender<const N: usize>(mut cb: SharedControlBlock<N>) -> Result<!,
// TODO: Use the correct PERSIST mode timer here.
let mut timeout: Duration = Duration::from_secs(1);
loop {
let yielder: Yielder = Yielder::new();
let clock_ref: SharedTimer = cb.get_timer();

futures::select_biased! {
_ = win_sz_changed => continue 'top,
_ = clock_ref.wait(timeout, yielder).fuse() => {
_ = clock_ref.wait(timeout, &yielder).fuse() => {
timeout *= 2;
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/rust/inetstack/protocols/tcp/established/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,13 @@ pub struct EstablishedSocket<const N: usize> {
impl<const N: usize> EstablishedSocket<N> {
pub fn new(
cb: SharedControlBlock<N>,
qd: QDesc,
dead_socket_tx: mpsc::UnboundedSender<QDesc>,
mut runtime: SharedDemiRuntime,
) -> Result<Self, Fail> {
// TODO: Maybe add the queue descriptor here.
let handle: TaskHandle = runtime.insert_background_coroutine(
"Inetstack::TCP::established::background",
Box::pin(background::background(cb.clone(), qd, dead_socket_tx)),
Box::pin(background::background(cb.clone(), dead_socket_tx)),
)?;
Ok(Self {
cb,
Expand Down
15 changes: 10 additions & 5 deletions src/rust/inetstack/protocols/tcp/passive_open.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use crate::{
scheduler::{
TaskHandle,
Yielder,
YielderHandle,
},
};
use ::libc::{
Expand Down Expand Up @@ -80,6 +81,7 @@ struct InflightAccept {
remote_window_scale: Option<u8>,
mss: usize,
handle: TaskHandle,
yielder_handle: YielderHandle,
}

#[derive(Default)]
Expand Down Expand Up @@ -212,7 +214,8 @@ impl<const N: usize> SharedPassiveSocket<N> {
local_window_scale, remote_window_scale
);

if let Some(inflight) = self.inflight.remove(&remote) {
if let Some(mut inflight) = self.inflight.remove(&remote) {
inflight.yielder_handle.wake_with(Ok(()));
if let Err(e) = self.runtime.remove_background_coroutine(&inflight.handle) {
panic!("Failed to remove inflight accept (error={:?})", e);
}
Expand Down Expand Up @@ -259,7 +262,9 @@ impl<const N: usize> SharedPassiveSocket<N> {
let local: SocketAddrV4 = self.local.clone();
let local_isn = self.isn_generator.generate(&local, &remote);
let remote_isn = header.seq_num;
let future = self.clone().background(remote, remote_isn, local_isn);
let yielder: Yielder = Yielder::new();
let yielder_handle: YielderHandle = yielder.get_handle();
let future = self.clone().background(remote, remote_isn, local_isn, yielder);
let handle: TaskHandle = self
.runtime
.insert_background_coroutine("Inetstack::TCP::passiveopen::background", Box::pin(future))?;
Expand All @@ -286,12 +291,13 @@ impl<const N: usize> SharedPassiveSocket<N> {
remote_window_scale,
mss,
handle,
yielder_handle,
};
self.inflight.insert(remote, accept);
Ok(())
}

async fn background(mut self, remote: SocketAddrV4, remote_isn: SeqNumber, local_isn: SeqNumber) {
async fn background(mut self, remote: SocketAddrV4, remote_isn: SeqNumber, local_isn: SeqNumber, yielder: Yielder) {
let handshake_retries: usize = self.tcp_config.get_handshake_retries();
let handshake_timeout: Duration = self.tcp_config.get_handshake_timeout();

Expand Down Expand Up @@ -327,8 +333,7 @@ impl<const N: usize> SharedPassiveSocket<N> {
};
self.transport.transmit(Box::new(segment));
let clock_ref: SharedTimer = self.runtime.get_timer();
let yielder: Yielder = Yielder::new();
if let Err(e) = clock_ref.wait(handshake_timeout, yielder).await {
if let Err(e) = clock_ref.wait(handshake_timeout, &yielder).await {
self.ready.push_err(e);
return;
}
Expand Down
3 changes: 1 addition & 2 deletions src/rust/inetstack/protocols/tcp/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ impl<const N: usize> SharedTcpPeer<N> {
let new_qd: QDesc = self.runtime.alloc_queue::<SharedTcpQueue<N>>(new_queue.clone());
// Set up established socket data structure.
let established: EstablishedSocket<N> =
EstablishedSocket::new(cb, new_qd, self.dead_socket_tx.clone(), self.runtime.clone())?;
EstablishedSocket::new(cb, self.dead_socket_tx.clone(), self.runtime.clone())?;
let local: SocketAddrV4 = established.cb.get_local();
let remote: SocketAddrV4 = established.cb.get_remote();
// Set the socket in the new queue to established
Expand Down Expand Up @@ -358,7 +358,6 @@ impl<const N: usize> SharedTcpPeer<N> {
let cb: SharedControlBlock<N> = socket.get_result(yielder).await?;
let new_socket = Socket::Established(EstablishedSocket::new(
cb,
qd,
self.dead_socket_tx.clone(),
self.runtime.clone(),
)?);
Expand Down
10 changes: 5 additions & 5 deletions src/rust/runtime/timer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ impl SharedTimer {
self.now
}

pub async fn wait(self, timeout: Duration, yielder: Yielder) -> Result<(), Fail> {
pub async fn wait(self, timeout: Duration, yielder: &Yielder) -> Result<(), Fail> {
let now: Instant = self.now;
self.wait_until(now + timeout, yielder).await
self.wait_until(now + timeout, &yielder).await
}

pub async fn wait_until(mut self, expiry: Instant, yielder: Yielder) -> Result<(), Fail> {
pub async fn wait_until(mut self, expiry: Instant, yielder: &Yielder) -> Result<(), Fail> {
let entry = TimerQueueEntry {
expiry,
yielder: yielder.get_handle(),
Expand Down Expand Up @@ -206,7 +206,7 @@ mod tests {
let timer_ref: SharedTimer = timer.clone();
let yielder: Yielder = Yielder::new();

let wait_future1 = timer_ref.wait(Duration::from_secs(2), yielder);
let wait_future1 = timer_ref.wait(Duration::from_secs(2), &yielder);
futures::pin_mut!(wait_future1);

crate::ensure_eq!(Future::poll(Pin::new(&mut wait_future1), &mut ctx).is_pending(), true);
Expand All @@ -217,7 +217,7 @@ mod tests {
let timer_ref2: SharedTimer = timer.clone();
let yielder2: Yielder = Yielder::new();
crate::ensure_eq!(Future::poll(Pin::new(&mut wait_future1), &mut ctx).is_pending(), true);
let wait_future2 = timer_ref2.wait(Duration::from_secs(1), yielder2);
let wait_future2 = timer_ref2.wait(Duration::from_secs(1), &yielder2);
futures::pin_mut!(wait_future2);

crate::ensure_eq!(Future::poll(Pin::new(&mut wait_future1), &mut ctx).is_pending(), true);
Expand Down
4 changes: 1 addition & 3 deletions src/rust/scheduler/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ impl YielderHandle {
"wake_with(): already scheduled, overwriting result (old={:?})",
old_result
);
}

if let Some(waker) = self.waker_handle.borrow_mut().take() {
} else if let Some(waker) = self.waker_handle.borrow_mut().take() {
waker.wake();
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/rust/scheduler/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,10 @@ impl Scheduler {
// Get the pinned ref.
let pinned_ptr = {
let pin_slab_index: usize = Scheduler::get_pin_slab_index(waker_page_index, waker_page_offset);
let pinned_ref: Pin<&mut Box<dyn Task>> = self.tasks.get_pin_mut(pin_slab_index).unwrap();
let pinned_ref: Pin<&mut Box<dyn Task>> = self
.tasks
.get_pin_mut(pin_slab_index)
.expect(format!("Invalid offset: {:?}", pin_slab_index).as_str());
let pinned_ptr = unsafe { Pin::into_inner_unchecked(pinned_ref) as *mut _ };
pinned_ptr
};
Expand Down

0 comments on commit ef7e1f1

Please sign in to comment.