Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement simple atomic stream select #585

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions lib_eio/stream.ml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,45 @@ module Locking = struct
Mutex.unlock t.mutex;
Some v

let select_of_many streams_fns =
let finished = Atomic.make false in
let cancel_fns = ref [] in
let add_cancel_fn fn = cancel_fns := fn :: !cancel_fns in
let cancel_all () = List.iter (fun fn -> fn ()) !cancel_fns in
let wait ctx enqueue (t, f) = begin
Mutex.lock t.mutex;
(* First check if any items are already available and return early if there are. *)
if not (Queue.is_empty t.items)
then (
cancel_all ();
Mutex.unlock t.mutex;
enqueue (Ok (f (Queue.take t.items))))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another domain could still have set finished to true using a different stream by the time this runs.

(also, you can't use Queue.take after unlocking)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for missing this, it should have been obvious! Please take another look, it seemed not too difficult to fix.

else add_cancel_fn @@
(* Otherwise, register interest in this stream. *)
Waiters.cancellable_await_internal ~mutex:(Some t.mutex) t.readers t.id ctx (fun r ->
if Result.is_ok r then (
if not (Atomic.compare_and_set finished false true) then (
(* Another stream has yielded an item in the meantime. However, as
we have been waiting on this stream it must have been empty.

As the stream's mutex was held since before last checking for an item,
the queue must be empty.
*)
assert ((Queue.length t.items) < t.capacity);
Queue.add (Result.get_ok r) t.items
) else (
(* remove all other entries of this fiber in other streams' waiters. *)
cancel_all ()
));
(* item is returned to waiting caller through enqueue and enter_unchecked. *)
enqueue (Result.map f r))
end in
(* Register interest in all streams and return first available item. *)
let wait_for_stream streams_fns = begin
Suspend.enter_unchecked (fun ctx enqueue -> List.iter (wait ctx enqueue) streams_fns)
end in
wait_for_stream streams_fns

let length t =
Mutex.lock t.mutex;
let len = Queue.length t.items in
Expand Down Expand Up @@ -125,6 +164,13 @@ let take_nonblocking = function
| Sync x -> Sync.take_nonblocking x
| Locking x -> Locking.take_nonblocking x

let select streams =
let filter s = match s with
| (Sync _, _) -> assert false
| (Locking x, f) -> (x, f)
in
Locking.select_of_many (List.map filter streams)

let length = function
| Sync _ -> 0
| Locking x -> Locking.length x
Expand Down
4 changes: 4 additions & 0 deletions lib_eio/stream.mli
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ val take_nonblocking : 'a t -> 'a option
Note that if another domain may add to the stream then a [None]
result may already be out-of-date by the time this returns. *)

val select : ('a t * ('a -> 'b)) list -> 'b
(** [select] returns the first item yielded by any stream. This only
works for streams with non-zero capacity. *)

val length : 'a t -> int
(** [length t] returns the number of items currently in [t]. *)

Expand Down
19 changes: 15 additions & 4 deletions lib_eio/waiters.ml
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ let rec wake_one t v =

let is_empty = Lwt_dllist.is_empty

let await_internal ~mutex (t:'a t) id ctx enqueue =
let cancellable_await_internal ~mutex (t:'a t) id ctx enqueue =
match Fiber_context.get_error ctx with
| Some ex ->
Option.iter Mutex.unlock mutex;
enqueue (Error ex)
enqueue (Error ex);
fun () -> ()
| None ->
let resolved_waiter = ref Hook.null in
let finished = Atomic.make false in
Expand All @@ -56,14 +57,24 @@ let await_internal ~mutex (t:'a t) id ctx enqueue =
enqueue (Error ex)
)
in
let unwait () =
if Atomic.compare_and_set finished false true
then Hook.remove !resolved_waiter
in
Fiber_context.set_cancel_fn ctx cancel;
let waiter = { enqueue; finished } in
match mutex with
| None ->
resolved_waiter := add_waiter t waiter
resolved_waiter := add_waiter t waiter;
unwait
| Some mutex ->
resolved_waiter := add_waiter_protected ~mutex t waiter;
Mutex.unlock mutex
Mutex.unlock mutex;
unwait

let await_internal ~mutex (t: 'a t) id ctx enqueue =
let _cancel = (cancellable_await_internal ~mutex t id ctx enqueue) in
()

(* Returns a result if the wait succeeds, or raises if cancelled. *)
let await ~mutex waiters id =
Expand Down
13 changes: 11 additions & 2 deletions lib_eio/waiters.mli
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ val await :
If [t] can be used from multiple domains:
- [mutex] must be set to the mutex to use to unlock it.
- [mutex] must be already held when calling this function, which will unlock it before blocking.
When [await] returns, [mutex] will have been unlocked.
@raise Cancel.Cancelled if the fiber's context is cancelled *)
When [await] returns, [mutex] will have been unlocked.
@raise Cancel.Cancelled if the fiber's context is cancelled *)

val await_internal :
mutex:Mutex.t option ->
Expand All @@ -40,3 +40,12 @@ val await_internal :
Note: [enqueue] is called from the triggering domain,
which is currently calling {!wake_one} or {!wake_all}
and must therefore be holding [mutex]. *)

val cancellable_await_internal :
mutex:Mutex.t option ->
'a t -> Ctf.id -> Fiber_context.t ->
(('a, exn) result -> unit) -> (unit -> unit)
(** Like [await_internal], but returns a function which, when called,
removes the current fiber continuation from the waiters list.
This is used when a fiber is waiting for multiple [Waiter]s simultaneously,
and needs to remove itself from other waiters once it has been enqueued by one.*)
19 changes: 19 additions & 0 deletions tests/stream.md
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,22 @@ Non-blocking take with zero-capacity stream:
+Got None from stream
- : unit = ()
```

Selecting from multiple channels:

```ocaml
# run @@ fun () -> Switch.run (fun sw ->
let t1, t2 = (S.create 2), (S.create 2) in
let selector = [(t1, fun x -> x); (t2, fun x -> x)] in
Fiber.fork ~sw (fun () -> S.add t2 "foo");
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
Fiber.fork ~sw (fun () -> S.add t2 "bar");
Fiber.fork ~sw (fun () -> S.add t1 "baz");
)
+foo
+bar
+baz
- : unit = ()
```