Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement simple atomic stream select #585

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions bench/bench_select.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

open Eio.Stdenv
open Eio

let sender_fibers = 4
let cap = 10

let message = 1234

(* Send [n_msgs] items to streams in a round-robin way. *)
let sender ~n_msgs streams =
let msgs = Seq.take n_msgs (Seq.ints 0) in
let streams = Seq.cycle (List.to_seq streams) in
let zipped = Seq.zip msgs streams in
ignore (Seq.iter (fun (_i, stream) ->
Stream.add stream message) zipped)

(* Start one sender fiber for each stream, and let it send n_msgs messages.
Each fiber sends to all streams in a round-robin way. *)
let run_senders ~dom_mgr ?(n_msgs = 100) streams =
Switch.run @@ fun sw ->
ignore @@ List.iter (fun _stream ->
Fiber.fork ~sw (fun () ->
Domain_manager.run dom_mgr (fun () ->
sender ~n_msgs streams))) streams

(* Receive messages from all streams. *)
let receiver ~n_msgs streams =
for _i = 1 to n_msgs do
assert (Int.equal message (Stream.select streams));
done

(* Create [n] streams. *)
let make_streams cap n =
let unfolder i = if i == 0 then None else Some (Stream.create cap, i-1) in
let seq = Seq.unfold unfolder n in
List.of_seq seq

let run env =
let dom_mgr = domain_mgr env in
let clock = clock env in
let streams = make_streams cap sender_fibers in
let selector = List.map (fun s -> (s, fun i -> i)) streams in
let n_msgs = 10000 in
Switch.run @@ fun sw ->
Fiber.fork ~sw (fun () -> run_senders ~dom_mgr ~n_msgs streams);
let before = Time.now clock in
receiver ~n_msgs:(sender_fibers * n_msgs) selector;
let after = Time.now clock in
let elapsed = after -. before in
let time_per_iter = elapsed /. (Float.of_int @@ sender_fibers * n_msgs) in
[Metric.create
(Printf.sprintf "sync:true senders:%d msgs_per_sender:%d" sender_fibers n_msgs)
(`Float (1e9 *. time_per_iter)) "ns"
"Time per transmitted int"]


1 change: 1 addition & 0 deletions bench/main.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ let benchmarks = [
"Stream", Bench_stream.run;
"HTTP", Bench_http.run;
"Eio_unix.Fd", Bench_fd.run;
"StreamSelect", Bench_select.run;
]

let usage_error () =
Expand Down
54 changes: 54 additions & 0 deletions lib_eio/stream.ml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,53 @@ module Locking = struct
Mutex.unlock t.mutex;
Some v

let select_of_many streams_fns =
let finished = Atomic.make false in
let cancel_fns = ref [] in
let add_cancel_fn fn = cancel_fns := fn :: !cancel_fns in
let cancel_all () = List.iter (fun fn -> fn ()) !cancel_fns in
let wait ctx enqueue (t, f) = begin
Mutex.lock t.mutex;
(* First check if any items are already available and return early if there are. *)
if not (Queue.is_empty t.items)
then (
(* If no other stream has yielded already, we are the first one. *)
if Atomic.compare_and_set finished false true
then (
(* Therefore, cancel all other waiters and take available item. *)
cancel_all ();
let item = Queue.take t.items in
ignore (Waiters.wake_one t.writers ());
enqueue (Ok (f item)));
Mutex.unlock t.mutex
)
else add_cancel_fn @@
(* Otherwise, register interest in this stream. *)
Waiters.cancellable_await_internal ~mutex:(Some t.mutex) t.readers t.id ctx (fun r ->
if Result.is_ok r then (
if not (Atomic.compare_and_set finished false true) then (
(* Another stream has yielded an item in the meantime. However, as
we have been waiting on this stream it must have been empty.

As the stream's mutex was held since before last checking for an item,
the queue must be empty.
*)
assert ((Queue.length t.items) < t.capacity);
Queue.add (Result.get_ok r) t.items
) else (
(* remove all other entries of this fiber in other streams' waiters. *)
ignore (Waiters.wake_one t.writers ());
cancel_all ();
(* item is returned to waiting caller through enqueue and enter_unchecked. *)
enqueue (Result.map f r))
));
end in
(* Register interest in all streams and return first available item. *)
let wait_for_stream streams_fns = begin
Suspend.enter_unchecked (fun ctx enqueue -> List.iter (wait ctx enqueue) streams_fns)
end in
wait_for_stream streams_fns

let length t =
Mutex.lock t.mutex;
let len = Queue.length t.items in
Expand Down Expand Up @@ -125,6 +172,13 @@ let take_nonblocking = function
| Sync x -> Sync.take_nonblocking x
| Locking x -> Locking.take_nonblocking x

let select streams =
let filter s = match s with
| (Sync _, _) -> assert false
| (Locking x, f) -> (x, f)
in
Locking.select_of_many (List.map filter streams)

let length = function
| Sync _ -> 0
| Locking x -> Locking.length x
Expand Down
4 changes: 4 additions & 0 deletions lib_eio/stream.mli
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ val take_nonblocking : 'a t -> 'a option
Note that if another domain may add to the stream then a [None]
result may already be out-of-date by the time this returns. *)

val select : ('a t * ('a -> 'b)) list -> 'b
(** [select] returns the first item yielded by any stream. This only
works for streams with non-zero capacity. *)

val length : 'a t -> int
(** [length t] returns the number of items currently in [t]. *)

Expand Down
19 changes: 15 additions & 4 deletions lib_eio/waiters.ml
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ let rec wake_one t v =

let is_empty = Lwt_dllist.is_empty

let await_internal ~mutex (t:'a t) id ctx enqueue =
let cancellable_await_internal ~mutex (t:'a t) id ctx enqueue =
match Fiber_context.get_error ctx with
| Some ex ->
Option.iter Mutex.unlock mutex;
enqueue (Error ex)
enqueue (Error ex);
fun () -> ()
| None ->
let resolved_waiter = ref Hook.null in
let finished = Atomic.make false in
Expand All @@ -56,14 +57,24 @@ let await_internal ~mutex (t:'a t) id ctx enqueue =
enqueue (Error ex)
)
in
let unwait () =
if Atomic.compare_and_set finished false true
then Hook.remove !resolved_waiter
in
Fiber_context.set_cancel_fn ctx cancel;
let waiter = { enqueue; finished } in
match mutex with
| None ->
resolved_waiter := add_waiter t waiter
resolved_waiter := add_waiter t waiter;
unwait
| Some mutex ->
resolved_waiter := add_waiter_protected ~mutex t waiter;
Mutex.unlock mutex
Mutex.unlock mutex;
unwait

let await_internal ~mutex (t: 'a t) id ctx enqueue =
let _cancel = (cancellable_await_internal ~mutex t id ctx enqueue) in
()

(* Returns a result if the wait succeeds, or raises if cancelled. *)
let await ~mutex waiters id =
Expand Down
13 changes: 11 additions & 2 deletions lib_eio/waiters.mli
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ val await :
If [t] can be used from multiple domains:
- [mutex] must be set to the mutex to use to unlock it.
- [mutex] must be already held when calling this function, which will unlock it before blocking.
When [await] returns, [mutex] will have been unlocked.
@raise Cancel.Cancelled if the fiber's context is cancelled *)
When [await] returns, [mutex] will have been unlocked.
@raise Cancel.Cancelled if the fiber's context is cancelled *)

val await_internal :
mutex:Mutex.t option ->
Expand All @@ -40,3 +40,12 @@ val await_internal :
Note: [enqueue] is called from the triggering domain,
which is currently calling {!wake_one} or {!wake_all}
and must therefore be holding [mutex]. *)

val cancellable_await_internal :
mutex:Mutex.t option ->
'a t -> Ctf.id -> Fiber_context.t ->
(('a, exn) result -> unit) -> (unit -> unit)
(** Like [await_internal], but returns a function which, when called,
removes the current fiber continuation from the waiters list.
This is used when a fiber is waiting for multiple [Waiter]s simultaneously,
and needs to remove itself from other waiters once it has been enqueued by one.*)
19 changes: 19 additions & 0 deletions tests/stream.md
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,22 @@ Non-blocking take with zero-capacity stream:
+Got None from stream
- : unit = ()
```

Selecting from multiple channels:

```ocaml
# run @@ fun () -> Switch.run (fun sw ->
let t1, t2 = (S.create 2), (S.create 2) in
let selector = [(t1, fun x -> x); (t2, fun x -> x)] in
Fiber.fork ~sw (fun () -> S.add t2 "foo");
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
Fiber.fork ~sw (fun () -> S.add t2 "bar");
Fiber.fork ~sw (fun () -> S.add t1 "baz");
)
+foo
+bar
+baz
- : unit = ()
```