Skip to content

Commit

Permalink
Fix IOCP on real windows
Browse files Browse the repository at this point in the history
  • Loading branch information
h33p committed Nov 3, 2023
1 parent b27521b commit 777a524
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 41 deletions.
79 changes: 69 additions & 10 deletions mfio-rt/src/native/impls/iocp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use mfio::backend::handle::{EventWaker, EventWakerOwner};
use mfio::backend::*;
use mfio::error::State;
use mfio::io::{Read as RdPerm, Write as WrPerm, *};
use mfio::mferr;
use mfio::tarc::BaseArc;
use parking_lot::Mutex;
use slab::Slab;
Expand All @@ -30,7 +31,8 @@ use ::windows::Win32::Foundation::{
WIN32_ERROR,
};
use ::windows::Win32::Networking::WinSock::{
AcceptEx, WSAGetLastError, WSARecv, WSASend, SOCKADDR, SOCKET, SOCK_STREAM, WSABUF,
setsockopt, AcceptEx, WSAGetLastError, WSARecv, WSASend, SOCKADDR, SOCKET, SOCK_STREAM,
SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, WSABUF,
};
use ::windows::Win32::Storage::FileSystem::{
ReadFile, SetFileCompletionNotificationModes, WriteFile,
Expand All @@ -54,8 +56,53 @@ pub use file::FileWrapper;
pub use tcp_listener::TcpListener;
pub use tcp_stream::{TcpConnectFuture, TcpStream};

fn win32_error() -> WIN32_ERROR {
WIN32_ERROR((Error::from_win32().code().0 & 0xFFFF) as _)
fn expand_error(e: Error) -> WIN32_ERROR {
let code = if e.code().0 == 0 {
Error::from_win32().code().0
} else {
e.code().0
};

WIN32_ERROR((code & 0xFFFF) as _)
}

fn socket_accepted(
s: usize,
l: SOCKET,
streams: &mut Slab<StreamInner>,
) -> Result<usize, mfio::error::Error> {
if unsafe {
setsockopt(
SOCKET(streams.get_mut(s).unwrap().socket.as_raw_socket() as _),
SOL_SOCKET,
SO_UPDATE_ACCEPT_CONTEXT,
Some(&l.0.to_ne_bytes()),
)
} != 0
{
Err(mferr!(500, Io, Other, Network))
} else {
Ok(s)
}
}

fn socket_connected(
s: usize,
streams: &mut Slab<StreamInner>,
) -> Result<usize, mfio::error::Error> {
if unsafe {
setsockopt(
SOCKET(streams.get_mut(s).unwrap().socket.as_raw_socket() as _),
SOL_SOCKET,
SO_UPDATE_CONNECT_CONTEXT,
None,
)
} != 0
{
Err(mferr!(500, Io, Other, Network))
} else {
Ok(s)
}
}

#[repr(C)]
Expand Down Expand Up @@ -217,6 +264,7 @@ impl OperationMode {

pub fn on_processed(
self,
handle: HANDLE,
res: std::io::Result<usize>,
connections: &mut Slab<TcpGetSock>,
streams: &mut Slab<StreamInner>,
Expand Down Expand Up @@ -275,7 +323,12 @@ impl OperationMode {
let conn = connections.get_mut(conn_id).unwrap();

match res {
Ok(_) => conn.res = Some(Ok(conn.socket_idx.take().unwrap())),
Ok(_) => {
conn.res = Some(
Ok(conn.socket_idx.take().unwrap())
.and_then(|v| socket_connected(v, streams)),
)
}
Err(e) => {
conn.res = {
streams.remove(conn.socket_idx.take().unwrap());
Expand All @@ -295,7 +348,12 @@ impl OperationMode {
conn.tmp_addr = Some(tmp_addr);

match res {
Ok(_) => conn.res = Some(Ok(conn.socket_idx.take().unwrap())),
Ok(_) => {
conn.res = Some(
Ok(conn.socket_idx.take().unwrap())
.and_then(|v| socket_accepted(v, SOCKET(handle.0 as _), streams)),
)
}
Err(e) => {
conn.res = {
streams.remove(conn.socket_idx.take().unwrap());
Expand Down Expand Up @@ -348,24 +406,25 @@ impl Operation {
let res = if res.is_ok() {
Ok(transferred as usize)
} else {
match win32_error() {
match expand_error(Error::OK) {
ERROR_IO_INCOMPLETE | ERROR_HANDLE_EOF | ERROR_NO_DATA => Ok(transferred as usize),
_ => Err(std::io::Error::from(std::io::ErrorKind::Other)),
}
};

self.mode
.on_processed(res, connections, streams, deferred_pkts)
.on_processed(header.handle, res, connections, streams, deferred_pkts)
}

pub fn on_error(
self,
mut self,
connections: &mut Slab<TcpGetSock>,
streams: &mut Slab<StreamInner>,
deferred_pkts: &mut DeferredPackets,
) {
log::trace!("On error");
self.mode.on_processed(
self.header.get_mut().handle,
Err(std::io::ErrorKind::Other.into()),
connections,
streams,
Expand Down Expand Up @@ -520,8 +579,8 @@ impl IocpState {
header.overlapped.InternalHigh = 0;
header.idx = idx;
match entry.mode.submit_op(header) {
Err(_) => {
match win32_error() {
Err(e) => {
match expand_error(e) {
ERROR_INVALID_USER_BUFFER | ERROR_NOT_ENOUGH_MEMORY => {
if queue_back {
self.pending_ops.push_back(Ok(idx))
Expand Down
49 changes: 29 additions & 20 deletions mfio-rt/src/native/impls/iocp/tcp_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use crate::{Shutdown, TcpStreamHandle};

use ::windows::Win32::Foundation::HANDLE;
use ::windows::Win32::Networking::WinSock::{
shutdown, WSAGetLastError, SD_BOTH, SD_RECEIVE, SD_SEND, SOCKET, WSABUF,
shutdown, WSAGetLastError, SD_BOTH, SD_RECEIVE, SD_SEND, SOCKET, WSABUF, WSAECONNRESET,
WSAENOTCONN,
};
use ::windows::Win32::System::IO::CancelIoEx;
use ::windows::Win32::System::IO::OVERLAPPED;
Expand Down Expand Up @@ -65,10 +66,32 @@ pub struct StreamInner {

impl Drop for StreamInner {
fn drop(&mut self) {
if unsafe { shutdown(SOCKET(self.socket.as_raw_socket() as _), SD_BOTH) } != 0 {
log::warn!("Could not shutdown stream: {:?}", unsafe {
WSAGetLastError()
});
let _ = self.shutdown(Shutdown::Both);
}
}

impl StreamInner {
fn shutdown(&self, how: Shutdown) -> Result<(), Error> {
let ret = unsafe {
shutdown(
SOCKET(self.socket.as_raw_socket() as _),
match how {
Shutdown::Read => SD_RECEIVE,
Shutdown::Write => SD_SEND,
Shutdown::Both => SD_BOTH,
},
)
};
if ret != 0 {
match unsafe { WSAGetLastError() } {
WSAECONNRESET => Ok(()),
v => {
log::error!("Unable to shutdown stream: {ret} {v:?}");
Err(mferr!(500, Io, Other, Network))
}
}
} else {
Ok(())
}
}
}
Expand Down Expand Up @@ -115,21 +138,7 @@ impl TcpStreamHandle for TcpStream {
.streams
.get(self.idx)
.ok_or_else(|| io_err(State::NotFound))?;
if unsafe {
shutdown(
SOCKET(stream.socket.as_raw_socket() as _),
match how {
Shutdown::Read => SD_RECEIVE,
Shutdown::Write => SD_SEND,
Shutdown::Both => SD_BOTH,
},
)
} != 0
{
Err(mferr!(500, Io, Other, Network))
} else {
Ok(())
}
stream.shutdown(how)
}
}

Expand Down
8 changes: 4 additions & 4 deletions mfio-rt/src/native/impls/windows_extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ impl CSockAddr {
pub fn to_socket_addr(self) -> SocketAddr {
match unsafe { self.generic.sa_family } {
AF_INET => SocketAddr::V4(SocketAddrV4::new(
unsafe { self.ipv4.sin_addr.S_un.S_addr }.into(),
unsafe { self.ipv4.sin_port },
u32::from_be(unsafe { self.ipv4.sin_addr.S_un.S_addr }).into(),
u16::from_be(unsafe { self.ipv4.sin_port }),
)),
AF_INET6 => SocketAddr::V6(SocketAddrV6::new(
unsafe { self.ipv6.sin6_addr.u.Word }.into(),
unsafe { self.ipv6.sin6_port },
unsafe { self.ipv6.sin6_flowinfo },
u16::from_be(unsafe { self.ipv6.sin6_port }),
u32::from_be(unsafe { self.ipv6.sin6_flowinfo }),
unsafe { self.ipv6.Anonymous.sin6_scope_id },
)),
_ => unreachable!(),
Expand Down
14 changes: 7 additions & 7 deletions mfio-rt/src/test_suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl<'a, T: Tcp> NetTestRun<'a, T> {

let _sem = TCP_SEM.acquire().await;

let listener = TcpListener::bind("0.0.0.0:0").unwrap();
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();

let jh = std::thread::spawn(move || {
Expand All @@ -185,7 +185,7 @@ impl<'a, T: Tcp> NetTestRun<'a, T> {

let _sem = TCP_SEM.acquire().await;

let listener = self.rt.bind("0.0.0.0:0").await.unwrap();
let listener = self.rt.bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

let jh = std::thread::spawn(move || {
Expand All @@ -194,7 +194,7 @@ impl<'a, T: Tcp> NetTestRun<'a, T> {

let mut listener = pin!(listener);

let _ = listener.next().await.unwrap();
let v = listener.next().await.unwrap();

Check warning on line 197 in mfio-rt/src/test_suite.rs

View workflow job for this annotation

GitHub Actions / clippy

unused variable: `v`

warning: unused variable: `v` --> mfio-rt/src/test_suite.rs:197:13 | 197 | let v = listener.next().await.unwrap(); | ^ help: if this is intentional, prefix it with an underscore: `_v` | = note: `#[warn(unused_variables)]` on by default

Check warning on line 197 in mfio-rt/src/test_suite.rs

View workflow job for this annotation

GitHub Actions / clippy

unused variable: `v`

warning: unused variable: `v` --> mfio-rt/src/test_suite.rs:197:13 | 197 | let v = listener.next().await.unwrap(); | ^ help: if this is intentional, prefix it with an underscore: `_v` | = note: `#[warn(unused_variables)]` on by default

jh.join().unwrap();
}
Expand All @@ -206,7 +206,7 @@ impl<'a, T: Tcp> NetTestRun<'a, T> {
self.ctx.files.iter().map(move |(name, data)| async move {
let _sem = TCP_SEM.acquire().await;

let listener = TcpListener::bind("0.0.0.0:0").unwrap();
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();

let (tx, rx) = flume::bounded(1);
Expand Down Expand Up @@ -242,7 +242,7 @@ impl<'a, T: Tcp> NetTestRun<'a, T> {
self.ctx.files.iter().map(move |(name, data)| async move {
let _sem = TCP_SEM.acquire().await;

let listener = TcpListener::bind("0.0.0.0:0").unwrap();
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();

let (tx, rx) = flume::bounded(1);
Expand Down Expand Up @@ -281,7 +281,7 @@ impl<'a, T: Tcp> NetTestRun<'a, T> {
self.ctx.files.iter().map(move |(name, data)| async move {
let _sem = TCP_SEM.acquire().await;

let listener = TcpListener::bind("0.0.0.0:0").unwrap();
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();

let (tx, rx) = flume::bounded(1);
Expand Down Expand Up @@ -333,7 +333,7 @@ impl<'a, T: Tcp> NetTestRun<'a, T> {
self.ctx.files.iter().map(move |(name, data)| async move {
let _sem = TCP_SEM.acquire().await;

let listener = self.rt.bind("0.0.0.0:0").await.unwrap();
let listener = self.rt.bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

let (tx, rx) = flume::bounded(1);
Expand Down

0 comments on commit 777a524

Please sign in to comment.