Skip to content

Commit

Permalink
rename async-tokio feature to tokio
Browse files Browse the repository at this point in the history
  • Loading branch information
imDema committed Mar 14, 2024
1 parent 41ec4f7 commit 17a2b60
Show file tree
Hide file tree
Showing 12 changed files with 58 additions and 59 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ readme = "README.md"
default = ["clap", "ssh", "timestamp"]
timestamp = []
ssh = ["ssh2", "whoami", "shell-escape", "sha2", "base64"]
async-tokio = ["tokio", "futures", "tokio/net", "tokio/io-util", "tokio/time", "tokio/rt-multi-thread", "tokio/macros"]
tokio = ["dep:tokio", "futures", "tokio/net", "tokio/io-util", "tokio/time", "tokio/rt-multi-thread", "tokio/macros"]
profiler = []

[dependencies]
Expand Down
4 changes: 2 additions & 2 deletions examples/wordcount.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
// s.split_whitespace().map(str::to_lowercase).collect()
// }

#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
fn main() {
tracing_subscriber::fmt::init();
let (config, args) = RuntimeConfig::from_args();
Expand Down Expand Up @@ -42,7 +42,7 @@ fn main() {
}
}

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
#[tokio::main()]
async fn main() {
tracing_subscriber::fmt::init();
Expand Down
4 changes: 2 additions & 2 deletions examples/wordcount_assoc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use noir_compute::prelude::*;
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;

#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
fn main() {
tracing_subscriber::fmt::init();

Expand Down Expand Up @@ -45,7 +45,7 @@ fn main() {
}
}

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
#[tokio::main(flavor = "current_thread")]
async fn main() {
tracing_subscriber::fmt::init();
Expand Down
2 changes: 1 addition & 1 deletion src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl StreamContext {
}

/// Start the computation. Await on the returned future to actually start the computation.
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
pub async fn execute(self) {
let mut env = self.inner.lock();
info!("starting execution ({} blocks)", env.block_count);
Expand Down
8 changes: 4 additions & 4 deletions src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ pub(crate) use topology::*;
use crate::operator::StreamElement;
use crate::scheduler::{BlockId, HostId, ReplicaId};

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
mod tokio;
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use tokio::*;

#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
mod sync;
#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
use sync::*;

mod network_channel;
Expand Down
8 changes: 4 additions & 4 deletions src/network/sync/remote.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use once_cell::sync::Lazy;
#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
use std::io::Read;
#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
use std::io::Write;

use bincode::config::{FixintEncoding, RejectTrailing, WithOtherIntEncoding, WithOtherTrailing};
Expand Down Expand Up @@ -43,7 +43,7 @@ struct MessageHeader {
/// The network protocol works as follow:
/// - send a `MessageHeader` serialized with bincode with `FixintEncoding`
/// - send the message
#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
pub(crate) fn remote_send<T: ExchangeData, W: Write>(
msg: NetworkMessage<T>,
dest: ReceiverEndpoint,
Expand Down Expand Up @@ -97,7 +97,7 @@ pub(crate) fn remote_send<T: ExchangeData, W: Write>(
/// last message.
///
/// The message won't be deserialized, use `deserialize()`.
#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
pub(crate) fn remote_recv<T: ExchangeData, R: Read>(
coord: DemuxCoord,
reader: &mut R,
Expand Down
12 changes: 6 additions & 6 deletions src/network/tokio/demultiplexer.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use tokio::io::AsyncWriteExt;
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use tokio::net::{TcpListener, TcpStream};
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use tokio::task::JoinHandle;

use anyhow::anyhow;
Expand All @@ -27,7 +27,7 @@ pub(crate) struct DemuxHandle<In: Send + 'static> {
tx_senders: UnboundedSender<(ReceiverEndpoint, Sender<NetworkMessage<In>>)>,
}

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
impl<In: ExchangeData> DemuxHandle<In> {
/// Construct a new `DemultiplexingReceiver` for a block.
///
Expand Down Expand Up @@ -64,7 +64,7 @@ impl<In: ExchangeData> DemuxHandle<In> {
}

/// Bind the socket of this demultiplexer.
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
async fn bind_remotes<In: ExchangeData>(
coord: DemuxCoord,
address: (String, u16),
Expand Down Expand Up @@ -161,7 +161,7 @@ async fn bind_remotes<In: ExchangeData>(
/// + Return an enum, either Queued or Overflowed
///
/// if overflowed send a yield request through a second channel
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
async fn demux_thread<In: ExchangeData>(
coord: DemuxCoord,
senders: HashMap<ReceiverEndpoint, Sender<NetworkMessage<In>>>,
Expand Down
18 changes: 9 additions & 9 deletions src/network/tokio/multiplexer.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
use std::io::ErrorKind;
use std::time::Duration;

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use std::net::ToSocketAddrs;
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use tokio::net::TcpStream;
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use tokio::task::JoinHandle;
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use tokio::time::sleep;

use crate::channel::{self, Receiver, Sender};
use crate::network::remote::remote_send;
use crate::network::{DemuxCoord, NetworkMessage, ReceiverEndpoint};
use crate::operator::ExchangeData;

// #[cfg(not(feature = "async-tokio"))]
// #[cfg(not(feature = "tokio"))]
// use crate::channel::Selector;

use crate::network::NetworkSender;

/// Maximum number of attempts to make for connecting to a remote host.
const CONNECT_ATTEMPTS: usize = 32;
/// Timeout for connecting to a remote host.
#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// To avoid spamming the connections, wait this timeout before trying again. If the connection
/// fails again this timeout will be doubled up to `RETRY_MAX_TIMEOUT`.
Expand All @@ -40,7 +40,7 @@ pub struct MultiplexingSender<Out: Send + 'static> {
tx: Option<Sender<(ReceiverEndpoint, NetworkMessage<Out>)>>,
}

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
impl<Out: ExchangeData> MultiplexingSender<Out> {
/// Construct a new `MultiplexingSender` for a block.
///
Expand Down Expand Up @@ -81,7 +81,7 @@ impl<Out: ExchangeData> MultiplexingSender<Out> {
/// - Then at most `CONNECT_ATTEMPTS` are performed, and an exponential backoff is used in case
/// of errors.
/// - If the connection cannot be established this function will panic.
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
async fn connect_remote(coord: DemuxCoord, address: (String, u16)) -> TcpStream {
let socket_addrs: Vec<_> = address
.to_socket_addrs()
Expand Down Expand Up @@ -137,7 +137,7 @@ async fn connect_remote(coord: DemuxCoord, address: (String, u16)) -> TcpStream
);
}

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
async fn mux_thread<Out: ExchangeData>(
coord: DemuxCoord,
rx: Receiver<(ReceiverEndpoint, NetworkMessage<Out>)>,
Expand Down
6 changes: 3 additions & 3 deletions src/network/tokio/remote.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use once_cell::sync::Lazy;
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

use bincode::config::{FixintEncoding, RejectTrailing, WithOtherIntEncoding, WithOtherTrailing};
Expand Down Expand Up @@ -41,7 +41,7 @@ struct MessageHeader {
/// The network protocol works as follow:
/// - send a `MessageHeader` serialized with bincode with `FixintEncoding`
/// - send the message
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
pub(crate) async fn remote_send<T: ExchangeData, W: AsyncWrite + Unpin>(
msg: NetworkMessage<T>,
dest: ReceiverEndpoint,
Expand Down Expand Up @@ -98,7 +98,7 @@ pub(crate) async fn remote_send<T: ExchangeData, W: AsyncWrite + Unpin>(
);
}

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
pub(crate) async fn remote_recv<T: ExchangeData, R: AsyncRead + Unpin>(
coord: DemuxCoord,
reader: &mut R,
Expand Down
28 changes: 14 additions & 14 deletions src/network/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use std::marker::PhantomData;
#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
use std::thread::JoinHandle;

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use futures::StreamExt;
use itertools::Itertools;
use typemap_rev::{TypeMap, TypeMapKey};
Expand Down Expand Up @@ -119,10 +119,10 @@ pub(crate) struct NetworkTopology {
demultiplexer_addresses: HashMap<DemuxCoord, (String, u16), crate::block::CoordHasherBuilder>,

/// The set of join handles of the various threads spawned by the topology.
#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
join_handles: Vec<JoinHandle<()>>,

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
async_join_handles: Vec<tokio::task::JoinHandle<()>>,
}

Expand Down Expand Up @@ -153,14 +153,14 @@ impl NetworkTopology {
used_receivers: Default::default(),
registered_receivers: Default::default(),
demultiplexer_addresses: Default::default(),
#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
join_handles: Default::default(),
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
async_join_handles: Default::default(),
}
}

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
/// Knowing that the computation ended, tear down the topology wait for all of its thread to
/// exit.
pub(crate) async fn stop_and_wait(&mut self) {
Expand All @@ -174,7 +174,7 @@ impl NetworkTopology {
.await;
}

#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
/// Knowing that the computation ended, tear down the topology wait for all of its thread to
/// exit.
pub(crate) fn stop_and_wait(&mut self) {
Expand Down Expand Up @@ -314,9 +314,9 @@ impl NetworkTopology {
if !prev.is_empty() {
let address = self.demultiplexer_addresses[&demux_coord].clone();
let (demux, join_handle) = DemuxHandle::new(demux_coord, address, prev.len());
#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
self.join_handles.push(join_handle);
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
self.async_join_handles.push(join_handle);
e.insert(demux);
} else {
Expand All @@ -343,9 +343,9 @@ impl NetworkTopology {
if let Entry::Vacant(e) = muxers.entry(demux_coord) {
let address = self.demultiplexer_addresses[&demux_coord].clone();
let (mux, join_handle) = MultiplexingSender::new(demux_coord, address);
#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
self.join_handles.push(join_handle);
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
self.async_join_handles.push(join_handle);
e.insert(mux);
}
Expand Down Expand Up @@ -627,7 +627,7 @@ mod tests {
);
}

#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
#[test]
fn test_remote_topology() {
let mut config = tempfile::NamedTempFile::new().unwrap();
Expand Down Expand Up @@ -766,7 +766,7 @@ mod tests {
join1.join().unwrap();
}

#[cfg(not(feature = "async-tokio"))]
#[cfg(not(feature = "tokio"))]
fn receiver<T: ExchangeData + Ord + std::fmt::Debug>(
receiver: NetworkReceiver<T>,
expected: Vec<T>,
Expand Down
15 changes: 7 additions & 8 deletions src/operator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::hash::Hash;
use std::ops::{AddAssign, Div};

use flume::{unbounded, Receiver};
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use futures::Future;
use serde::{Deserialize, Serialize};

Expand All @@ -22,7 +22,7 @@ use crate::scheduler::ExecutionMetadata;
use crate::stream::KeyedItem;
use crate::{BatchMode, KeyedStream, Stream};

#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
use self::map_async::MapAsync;
use self::map_memo::MapMemo;
use self::sink::collect::Collect;
Expand Down Expand Up @@ -72,7 +72,7 @@ pub mod join;
mod key_by;
mod keyed_fold;
mod map;
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
mod map_async;
mod map_memo;
mod merge;
Expand Down Expand Up @@ -200,7 +200,7 @@ impl<Out> StreamElement<Out> {
}

/// Change the type of the element inside the `StreamElement`.
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
pub async fn map_async<NewOut, F, Fut>(self, f: F) -> StreamElement<NewOut>
where
F: FnOnce(Out) -> Fut,
Expand Down Expand Up @@ -573,7 +573,7 @@ where
/// assert_eq!(res.get().unwrap(), vec![4, 1, 0, 1, 4, 2, 2, 4, 1, 0]);
/// # }
/// ```
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
pub fn map_async_memo_by<O, K, F, Fk, Fut>(
self,
f: F,
Expand All @@ -587,7 +587,6 @@ where
O: Clone + Send + Sync + 'static,
K: DataKey + Sync,
{
use crate::block::GroupHasherBuilder;
use futures::FutureExt;
use quick_cache::{sync::Cache, UnitWeighter};
use std::{convert::Infallible, sync::Arc};
Expand Down Expand Up @@ -637,7 +636,7 @@ where
/// assert_eq!(res.get().unwrap(), vec![4, 1, 0, 1, 4, 2, 2, 4, 1, 0]);
/// # }
/// ```
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
pub fn map_async<O: Data, F, Fut>(self, f: F) -> Stream<impl Operator<Out = O>>
where
F: Fn(Op::Out) -> Fut + Send + Sync + 'static + Clone,
Expand Down Expand Up @@ -2011,7 +2010,7 @@ where
/// assert_eq!(res.get().unwrap(), vec![0, 1, 4, 9, 0, 1, 4, 9, 0, 1]);
/// # }
/// ```
#[cfg(feature = "async-tokio")]
#[cfg(feature = "tokio")]
pub fn map_async_memo<O: Clone + Send + Sync + 'static, F, Fut>(
self,
f: F,
Expand Down
Loading

0 comments on commit 17a2b60

Please sign in to comment.