From 0aec9235db17fd57e300a31e2673399c2c73f51f Mon Sep 17 00:00:00 2001 From: Garret Kelly Date: Sat, 12 Nov 2022 18:52:22 -0500 Subject: [PATCH] RFC: Initial vsock stream support --- examples/Cargo.toml | 4 + examples/vsock.rs | 22 +++++ glommio/Cargo.toml | 1 + glommio/src/net/mod.rs | 2 + glommio/src/net/vsock.rs | 174 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 203 insertions(+) create mode 100644 examples/vsock.rs create mode 100644 glommio/src/net/vsock.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index c8a04a1f2..532fe760a 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -59,3 +59,7 @@ path = "hyper.rs" [[example]] name = "gate" path = "gate.rs" + +[[example]] +name = "vsock" +path = "vsock.rs" diff --git a/examples/vsock.rs b/examples/vsock.rs new file mode 100644 index 000000000..9cb6e56a1 --- /dev/null +++ b/examples/vsock.rs @@ -0,0 +1,22 @@ +use futures_lite::io::{AsyncReadExt, AsyncWriteExt}; +use glommio::{ + net::{VsockListener, VsockStream}, + prelude::*, +}; +use std::io::Result; + +fn main() -> Result<()> { + let executor = LocalExecutor::default(); + executor.run(async move { + let listener = VsockListener::bind_with_cid_port(u32::MAX, 1337).unwrap(); + let mut stream = listener.accept().await.unwrap().buffered(); + + let mut buf = [0u8; 1]; + while stream.read(&mut buf).await.unwrap() != 0 { + stream.write(&buf).await.unwrap(); + } + + println!("done!"); + }); + Ok(()) +} diff --git a/glommio/Cargo.toml b/glommio/Cargo.toml index 11b2dd813..6fc25f279 100755 --- a/glommio/Cargo.toml +++ b/glommio/Cargo.toml @@ -42,6 +42,7 @@ smallvec = { version = "1.7", features = ["union"] } socket2 = { version = "0.4", features = ["all"] } tracing = "0.1" typenum = "1.15" +vsock = "0.3.0" [dev-dependencies] fastrand = "1" diff --git a/glommio/src/net/mod.rs b/glommio/src/net/mod.rs index 7a05b14f7..1d28842b1 100644 --- a/glommio/src/net/mod.rs +++ b/glommio/src/net/mod.rs @@ -111,9 +111,11 @@ mod stream; mod tcp_socket; mod udp_socket; mod unix; +mod vsock; pub use self::{ stream::{Buffered, Preallocated}, tcp_socket::{AcceptedTcpStream, TcpListener, TcpStream}, udp_socket::UdpSocket, unix::{AcceptedUnixStream, UnixDatagram, UnixListener, UnixStream}, + vsock::{AcceptedVsockStream, VsockListener, VsockStream}, }; diff --git a/glommio/src/net/vsock.rs b/glommio/src/net/vsock.rs new file mode 100644 index 000000000..9d00aee00 --- /dev/null +++ b/glommio/src/net/vsock.rs @@ -0,0 +1,174 @@ +use super::stream::GlommioStream; +use crate::{ + net::{ + stream::{Buffered, NonBuffered, Preallocated, RxBuf}, + yolo_accept, + }, + reactor::Reactor, + GlommioError, +}; +use futures_lite::{ + io::{AsyncBufRead, AsyncRead, AsyncWrite}, + stream, +}; +use nix::sys::socket::SockAddr; +use pin_project_lite::pin_project; +use socket2::{Domain, Socket, Type}; +use std::{ + io, + os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, + pin::Pin, + rc::{Rc, Weak}, + task::{Context, Poll}, +}; + +type Result = crate::Result; + +#[derive(Debug)] +pub struct VsockListener { + reactor: Weak, + listener: vsock::VsockListener, +} + +impl VsockListener { + pub fn bind_with_cid_port(cid: u32, port: u32) -> Result { + let listener = vsock::VsockListener::bind_with_cid_port(cid, port)?; + + Ok(VsockListener { + reactor: Rc::downgrade(&crate::executor().reactor()), + listener, + }) + } + + pub async fn shared_accept(&self) -> Result { + let reactor = self.reactor.upgrade().unwrap(); + let raw_fd = self.listener.as_raw_fd(); + if let Some(r) = yolo_accept(raw_fd) { + match r { + Ok(fd) => { + return Ok(AcceptedVsockStream { fd }); + } + Err(err) => return Err(GlommioError::IoError(err)), + } + } + let source = reactor.accept(self.listener.as_raw_fd()); + let fd = source.collect_rw().await?; + Ok(AcceptedVsockStream { fd: fd as RawFd }) + } + + pub async fn accept(&self) -> Result { + Ok(self.shared_accept().await?.bind_to_executor()) + } + + pub fn incoming(&self) -> impl stream::Stream> + Unpin + '_ { + Box::pin(stream::unfold(self, |listener| async move { + Some((listener.accept().await, listener)) + })) + } +} + +#[derive(Copy, Clone, Debug)] +pub struct AcceptedVsockStream { + fd: RawFd, +} + +impl AcceptedVsockStream { + pub fn bind_to_executor(self) -> VsockStream { + VsockStream { + stream: unsafe { GlommioStream::from_raw_fd(self.fd) }, + } + } +} + +pin_project! { + #[derive(Debug)] + pub struct VsockStream { + stream: GlommioStream, + } +} + +impl VsockStream { + pub async fn connect_with_cid_port(cid: u32, port: u32) -> Result { + let socket = Socket::new(Domain::VSOCK, Type::STREAM, None)?; + let addr = SockAddr::new_vsock(cid, port); + let reactor = crate::executor().reactor(); + let source = reactor.connect(socket.as_raw_fd(), addr); + source.collect_rw().await?; + + Ok(VsockStream { + stream: GlommioStream::from(socket), + }) + } + + pub fn buffered(self) -> VsockStream { + self.buffered_with(Preallocated::default()) + } + + pub fn buffered_with(self, buf: B) -> VsockStream { + VsockStream { + stream: self.stream.buffered_with(buf), + } + } +} + +impl AsyncBufRead for VsockStream { + fn poll_fill_buf<'a>( + self: Pin<&'a mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().stream.poll_fill_buf(cx) + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + self.stream.consume(amt); + } +} + +impl AsyncRead for VsockStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.stream).poll_read(cx, buf) + } +} + +impl AsyncWrite for VsockStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.stream).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_close(cx) + } +} + +#[derive(Debug)] +struct Stream(vsock::VsockStream); + +impl From for Stream { + fn from(socket: socket2::Socket) -> Stream { + Self(unsafe { vsock::VsockStream::from_raw_fd(socket.into_raw_fd()) }) + } +} + +impl AsRawFd for Stream { + fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() + } +} + +impl FromRawFd for Stream { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + Self(vsock::VsockStream::from_raw_fd(fd)) + } +}