Skip to content

Commit

Permalink
Avoid spawn_blocking and poll and asyncfd. (#6906)
Browse files Browse the repository at this point in the history
* Avoid spawn_blocking and poll and asyncfd.

Replace the Unix-specific AsyncFd and Windows poll code using
tokio::net::TcpStream.

prtest:full

* Delete dead code.
  • Loading branch information
sunfishcode authored Aug 24, 2023
1 parent c4aee34 commit 68ba286
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 131 deletions.
78 changes: 31 additions & 47 deletions crates/wasi/src/preview2/host/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ use crate::preview2::stream::TableStreamExt;
use crate::preview2::tcp::{HostTcpSocket, HostTcpState, TableTcpSocketExt};
use crate::preview2::{HostPollable, PollableFuture, WasiView};
use cap_net_ext::{Blocking, PoolExt, TcpListenerExt};
use cap_std::net::TcpListener;
use io_lifetimes::AsSocketlike;
use rustix::io::Errno;
use rustix::net::sockopt;
use std::any::Any;
#[cfg(unix)]
use tokio::io::Interest;
#[cfg(not(unix))]
use tokio::task::spawn_blocking;

impl<T: WasiView> tcp::Host for T {
fn start_bind(
Expand All @@ -38,7 +36,9 @@ impl<T: WasiView> tcp::Host for T {
let binder = network.0.tcp_binder(local_address)?;

// Perform the OS bind call.
binder.bind_existing_tcp_listener(socket.tcp_socket())?;
binder.bind_existing_tcp_listener(
&*socket.tcp_socket().as_socketlike_view::<TcpListener>(),
)?;

let socket = table.get_tcp_socket_mut(this)?;
socket.tcp_state = HostTcpState::BindStarted;
Expand Down Expand Up @@ -67,19 +67,27 @@ impl<T: WasiView> tcp::Host for T {
remote_address: IpSocketAddress,
) -> Result<(), network::Error> {
let table = self.table_mut();
let socket = table.get_tcp_socket(this)?;
let r = {
let socket = table.get_tcp_socket(this)?;

match socket.tcp_state {
HostTcpState::Default => {}
HostTcpState::Connected => return Err(ErrorCode::AlreadyConnected.into()),
_ => return Err(ErrorCode::NotInProgress.into()),
}
match socket.tcp_state {
HostTcpState::Default => {}
HostTcpState::Connected => return Err(ErrorCode::AlreadyConnected.into()),
_ => return Err(ErrorCode::NotInProgress.into()),
}

let network = table.get_network(network)?;
let connecter = network.0.tcp_connecter(remote_address)?;
let network = table.get_network(network)?;
let connecter = network.0.tcp_connecter(remote_address)?;

// Do an OS `connect`. Our socket is non-blocking, so it'll either...
{
let view = &*socket.tcp_socket().as_socketlike_view::<TcpListener>();
let r = connecter.connect_existing_tcp_listener(view);
r
}
};

// Do an OS `connect`. Our socket is non-blocking, so it'll either...
match connecter.connect_existing_tcp_listener(socket.tcp_socket()) {
match r {
// succeed immediately,
Ok(()) => {
let socket = table.get_tcp_socket_mut(this)?;
Expand Down Expand Up @@ -155,7 +163,10 @@ impl<T: WasiView> tcp::Host for T {
_ => return Err(ErrorCode::NotInProgress.into()),
}

socket.tcp_socket().listen(None)?;
socket
.tcp_socket()
.as_socketlike_view::<TcpListener>()
.listen(None)?;

socket.tcp_state = HostTcpState::ListenStarted;

Expand Down Expand Up @@ -190,7 +201,10 @@ impl<T: WasiView> tcp::Host for T {
}

// Do the OS accept call.
let (connection, _addr) = socket.tcp_socket().accept_with(Blocking::No)?;
let (connection, _addr) = socket
.tcp_socket()
.as_socketlike_view::<TcpListener>()
.accept_with(Blocking::No)?;
let tcp_socket = HostTcpSocket::from_tcp_stream(connection)?;

let input_clone = tcp_socket.clone_inner();
Expand Down Expand Up @@ -412,43 +426,13 @@ impl<T: WasiView> tcp::Host for T {
}

// FIXME: Add `Interest::ERROR` when we update to tokio 1.32.
#[cfg(unix)]
let join = Box::pin(async move {
socket
.inner
.tcp_socket
.ready(Interest::READABLE | Interest::WRITABLE)
.await
.unwrap()
.retain_ready();
Ok(())
});

#[cfg(not(unix))]
let join = Box::pin(async move {
let clone = socket.clone_inner();
spawn_blocking(move || loop {
#[cfg(not(windows))]
let poll_flags = rustix::event::PollFlags::IN
| rustix::event::PollFlags::OUT
| rustix::event::PollFlags::ERR
| rustix::event::PollFlags::HUP;
// Windows doesn't appear to support `HUP`, or `ERR`
// combined with `IN`/`OUT`.
#[cfg(windows)]
let poll_flags = rustix::event::PollFlags::IN | rustix::event::PollFlags::OUT;
match rustix::event::poll(
&mut [rustix::event::PollFd::new(&clone.tcp_socket, poll_flags)],
-1,
) {
Ok(_) => break,
Err(Errno::INTR) => (),
Err(err) => Err(err).unwrap(),
}
})
.await
.unwrap();

.unwrap();
Ok(())
});

Expand Down
105 changes: 21 additions & 84 deletions crates/wasi/src/preview2/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::preview2::{HostInputStream, HostOutputStream, StreamState, Table, Tab
use bytes::{Bytes, BytesMut};
use cap_net_ext::{AddressFamily, Blocking, TcpListenerExt};
use cap_std::net::{TcpListener, TcpStream};
use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike};
use io_lifetimes::AsSocketlike;
use std::io;
use std::sync::Arc;
Expand Down Expand Up @@ -41,9 +42,7 @@ pub(crate) enum HostTcpState {
/// A host TCP socket, plus associated bookkeeping.
///
/// The inner state is wrapped in an Arc because the same underlying socket is
/// used for implementing the stream types. Also needed for [`spawn_blocking`].
///
/// [`spawn_blocking`]: Self::spawn_blocking
/// used for implementing the stream types.
pub(crate) struct HostTcpSocket {
/// The part of a `HostTcpSocket` which is reference-counted so that we
/// can pass it to async tasks.
Expand All @@ -55,13 +54,7 @@ pub(crate) struct HostTcpSocket {

/// The inner reference-counted state of a `HostTcpSocket`.
pub(crate) struct HostTcpSocketInner {
/// On Unix-family platforms we can use `AsyncFd` for efficient polling.
#[cfg(unix)]
pub(crate) tcp_socket: tokio::io::unix::AsyncFd<cap_std::net::TcpListener>,

/// On non-Unix, we can use plain `poll`.
#[cfg(not(unix))]
pub(crate) tcp_socket: cap_std::net::TcpListener,
pub(crate) tcp_socket: tokio::net::TcpStream,
}

impl HostTcpSocket {
Expand All @@ -71,9 +64,12 @@ impl HostTcpSocket {
// by our async implementation.
let tcp_socket = TcpListener::new(family, Blocking::No)?;

// On Unix, pack it up in an `AsyncFd` so we can efficiently poll it.
#[cfg(unix)]
let tcp_socket = tokio::io::unix::AsyncFd::new(tcp_socket)?;
let tcp_socket = unsafe {
tokio::net::TcpStream::try_from(std::net::TcpStream::from_raw_socketlike(
tcp_socket.into_raw_socketlike(),
))
.unwrap()
};

Ok(Self {
inner: Arc::new(HostTcpSocketInner { tcp_socket }),
Expand All @@ -88,17 +84,20 @@ impl HostTcpSocket {
let fd = rustix::fd::OwnedFd::from(tcp_socket);
let tcp_socket = TcpListener::from(fd);

// On Unix, pack it up in an `AsyncFd` so we can efficiently poll it.
#[cfg(unix)]
let tcp_socket = tokio::io::unix::AsyncFd::new(tcp_socket)?;
let tcp_socket = unsafe {
tokio::net::TcpStream::try_from(std::net::TcpStream::from_raw_socketlike(
tcp_socket.into_raw_socketlike(),
))
.unwrap()
};

Ok(Self {
inner: Arc::new(HostTcpSocketInner { tcp_socket }),
tcp_state: HostTcpState::Default,
})
}

pub fn tcp_socket(&self) -> &cap_std::net::TcpListener {
pub fn tcp_socket(&self) -> &tokio::net::TcpStream {
self.inner.tcp_socket()
}

Expand All @@ -108,29 +107,11 @@ impl HostTcpSocket {
}

impl HostTcpSocketInner {
pub fn tcp_socket(&self) -> &cap_std::net::TcpListener {
pub fn tcp_socket(&self) -> &tokio::net::TcpStream {
let tcp_socket = &self.tcp_socket;

// Unpack the `AsyncFd`.
#[cfg(unix)]
let tcp_socket = tcp_socket.get_ref();

tcp_socket
}

/// Spawn a task on tokio's blocking thread for performing blocking
/// syscalls on the underlying [`cap_std::net::TcpListener`].
#[cfg(not(unix))]
pub(crate) async fn spawn_blocking<F, R>(self: &Arc<Self>, body: F) -> R
where
F: FnOnce(&cap_std::net::TcpListener) -> R + Send + 'static,
R: Send + 'static,
{
let s = Arc::clone(self);
tokio::task::spawn_blocking(move || body(s.tcp_socket()))
.await
.unwrap()
}
}

#[async_trait::async_trait]
Expand All @@ -150,30 +131,8 @@ impl HostInputStream for Arc<HostTcpSocketInner> {
}

async fn ready(&mut self) -> anyhow::Result<()> {
#[cfg(unix)]
{
self.tcp_socket.readable().await?.retain_ready();
Ok(())
}

#[cfg(not(unix))]
{
self.spawn_blocking(move |tcp_socket| {
match rustix::event::poll(
&mut [rustix::event::PollFd::new(
tcp_socket,
rustix::event::PollFlags::IN
| rustix::event::PollFlags::ERR
| rustix::event::PollFlags::HUP,
)],
-1,
) {
Ok(_) => Ok(()),
Err(err) => Err(err.into()),
}
})
.await
}
self.tcp_socket.readable().await?;
Ok(())
}
}

Expand All @@ -192,30 +151,8 @@ impl HostOutputStream for Arc<HostTcpSocketInner> {
}

async fn ready(&mut self) -> anyhow::Result<()> {
#[cfg(unix)]
{
self.tcp_socket.writable().await?.retain_ready();
Ok(())
}

#[cfg(not(unix))]
{
self.spawn_blocking(move |tcp_socket| {
match rustix::event::poll(
&mut [rustix::event::PollFd::new(
tcp_socket,
rustix::event::PollFlags::OUT
| rustix::event::PollFlags::ERR
| rustix::event::PollFlags::HUP,
)],
-1,
) {
Ok(_) => Ok(()),
Err(err) => Err(err.into()),
}
})
.await
}
self.tcp_socket.writable().await?;
Ok(())
}
}

Expand Down

0 comments on commit 68ba286

Please sign in to comment.