diff --git a/src/Dagger.jl b/src/Dagger.jl index afbc37b9..70725340 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -71,9 +71,9 @@ include("sch/Sch.jl"); using .Sch include("datadeps.jl") # Streaming +include("stream.jl") include("stream-buffers.jl") include("stream-transfer.jl") -include("stream.jl") # Array computations include("array/darray.jl") diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl index cd00000b..579a5975 100644 --- a/src/stream-buffers.jl +++ b/src/stream-buffers.jl @@ -1,50 +1,35 @@ """ -A buffer that drops all elements put into it. Only to be used as the output -buffer for a task - will throw if attached as an input. +A buffer that drops all elements put into it. """ -struct DropBuffer{T} end +mutable struct DropBuffer{T} + open::Bool + DropBuffer{T}() where T = new{T}(true) +end DropBuffer{T}(_) where T = DropBuffer{T}() Base.isempty(::DropBuffer) = true isfull(::DropBuffer) = false -Base.put!(::DropBuffer, _) = nothing -Base.take!(::DropBuffer) = error("Cannot `take!` from a DropBuffer") - -"A process-local buffer backed by a `Channel{T}`." -struct ChannelBuffer{T} - channel::Channel{T} - len::Int - count::Threads.Atomic{Int} - ChannelBuffer{T}(len::Int=1024) where T = - new{T}(Channel{T}(len), len, Threads.Atomic{Int}(0)) -end -Base.isempty(cb::ChannelBuffer) = isempty(cb.channel) -isfull(cb::ChannelBuffer) = cb.count[] == cb.len -function Base.put!(cb::ChannelBuffer{T}, x) where T - put!(cb.channel, convert(T, x)) - Threads.atomic_add!(cb.count, 1) -end -function Base.take!(cb::ChannelBuffer) - take!(cb.channel) - Threads.atomic_sub!(cb.count, 1) -end - -"A cross-worker buffer backed by a `RemoteChannel{T}`." -struct RemoteChannelBuffer{T} - channel::RemoteChannel{Channel{T}} - len::Int - count::Threads.Atomic{Int} - RemoteChannelBuffer{T}(len::Int=1024) where T = - new{T}(RemoteChannel(()->Channel{T}(len)), len, Threads.Atomic{Int}(0)) -end -Base.isempty(cb::RemoteChannelBuffer) = isempty(cb.channel) -isfull(cb::RemoteChannelBuffer) = cb.count[] == cb.len -function Base.put!(cb::RemoteChannelBuffer{T}, x) where T - put!(cb.channel, convert(T, x)) - Threads.atomic_add!(cb.count, 1) -end -function Base.take!(cb::RemoteChannelBuffer) - take!(cb.channel) - Threads.atomic_sub!(cb.count, 1) +capacity(::DropBuffer) = typemax(Int) +Base.length(::DropBuffer) = 0 +Base.isopen(buf::DropBuffer) = buf.open +function Base.close(buf::DropBuffer) + buf.open = false +end +function Base.put!(buf::DropBuffer, _) + if !isopen(buf) + throw(InvalidStateException("DropBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + yield() + return +end +function Base.take!(buf::DropBuffer) + while true + if !isopen(buf) + throw(InvalidStateException("DropBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + yield() + end end "A process-local ring buffer." @@ -53,7 +38,7 @@ mutable struct ProcessRingBuffer{T} write_idx::Int @atomic count::Int buffer::Vector{T} - open::Bool + @atomic open::Bool function ProcessRingBuffer{T}(len::Int=1024) where T buffer = Vector{T}(undef, len) return new{T}(1, 1, 0, buffer, true) @@ -61,32 +46,37 @@ mutable struct ProcessRingBuffer{T} end Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0 isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer) +capacity(rb::ProcessRingBuffer) = length(rb.buffer) Base.length(rb::ProcessRingBuffer) = @atomic rb.count -Base.isopen(rb::ProcessRingBuffer) = rb.open +Base.isopen(rb::ProcessRingBuffer) = @atomic rb.open function Base.close(rb::ProcessRingBuffer) - rb.open = false + @atomic rb.open = false end function Base.put!(rb::ProcessRingBuffer{T}, x) where T - len = length(rb.buffer) - while (@atomic rb.count) == len + while isfull(rb) yield() if !isopen(rb) - throw(InvalidStateException("Stream is closed", :closed)) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) end - task_may_cancel!() + task_may_cancel!(; must_force=true) end - to_write_idx = mod1(rb.write_idx, len) + to_write_idx = mod1(rb.write_idx, length(rb.buffer)) rb.buffer[to_write_idx] = convert(T, x) rb.write_idx += 1 @atomic rb.count += 1 end function Base.take!(rb::ProcessRingBuffer) - while (@atomic rb.count) == 0 + while isempty(rb) yield() - if !isopen(rb) - throw(InvalidStateException("Stream is closed", :closed)) + if !isopen(rb) && isempty(rb) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) end - task_may_cancel!() + if task_cancelled() && isempty(rb) + # We respect a graceful cancellation only if the buffer is empty. + # Otherwise, we may have values to continue communicating. + task_may_cancel!() + end + task_may_cancel!(; must_force=true) end to_read_idx = rb.read_idx rb.read_idx += 1 @@ -106,123 +96,3 @@ function collect!(rb::ProcessRingBuffer{T}) where T return output end - -#= TODO -"A server-local ring buffer backed by shared-memory." -mutable struct ServerRingBuffer{T} - read_idx::Int - write_idx::Int - @atomic count::Int - buffer::Vector{T} - function ServerRingBuffer{T}(len::Int=1024) where T - buffer = Vector{T}(undef, len) - return new{T}(1, 1, 0, buffer) - end -end -Base.isempty(rb::ServerRingBuffer) = (@atomic rb.count) == 0 -function Base.put!(rb::ServerRingBuffer{T}, x) where T - len = length(rb.buffer) - while (@atomic rb.count) == len - yield() - end - to_write_idx = mod1(rb.write_idx, len) - rb.buffer[to_write_idx] = convert(T, x) - rb.write_idx += 1 - @atomic rb.count += 1 -end -function Base.take!(rb::ServerRingBuffer) - while (@atomic rb.count) == 0 - yield() - end - to_read_idx = rb.read_idx - rb.read_idx += 1 - @atomic rb.count -= 1 - to_read_idx = mod1(to_read_idx, length(rb.buffer)) - return rb.buffer[to_read_idx] -end -=# - -#= -"A TCP-based ring buffer." -mutable struct TCPRingBuffer{T} - read_idx::Int - write_idx::Int - @atomic count::Int - buffer::Vector{T} - function TCPRingBuffer{T}(len::Int=1024) where T - buffer = Vector{T}(undef, len) - return new{T}(1, 1, 0, buffer) - end -end -Base.isempty(rb::TCPRingBuffer) = (@atomic rb.count) == 0 -function Base.put!(rb::TCPRingBuffer{T}, x) where T - len = length(rb.buffer) - while (@atomic rb.count) == len - yield() - end - to_write_idx = mod1(rb.write_idx, len) - rb.buffer[to_write_idx] = convert(T, x) - rb.write_idx += 1 - @atomic rb.count += 1 -end -function Base.take!(rb::TCPRingBuffer) - while (@atomic rb.count) == 0 - yield() - end - to_read_idx = rb.read_idx - rb.read_idx += 1 - @atomic rb.count -= 1 - to_read_idx = mod1(to_read_idx, length(rb.buffer)) - return rb.buffer[to_read_idx] -end -=# - -#= -""" -A flexible puller which switches to the most efficient buffer type based -on the sender and receiver locations. -""" -mutable struct UniBuffer{T} - buffer::Union{ProcessRingBuffer{T}, Nothing} -end -function initialize_stream_buffer!(::Type{UniBuffer{T}}, T, send_proc, recv_proc, buffer_amount) where T - if buffer_amount == 0 - error("Return NullBuffer") - end - send_osproc = get_parent(send_proc) - recv_osproc = get_parent(recv_proc) - if send_osproc.pid == recv_osproc.pid - inner = RingBuffer{T}(buffer_amount) - elseif system_uuid(send_osproc.pid) == system_uuid(recv_osproc.pid) - inner = ProcessBuffer{T}(buffer_amount) - else - inner = RemoteBuffer{T}(buffer_amount) - end - return UniBuffer{T}(buffer_amount) -end - -struct LocalPuller{T,B} - buffer::B{T} - id::UInt - function LocalPuller{T,B}(id::UInt, buffer_amount::Integer) where {T,B} - buffer = initialize_stream_buffer!(B, T, buffer_amount) - return new{T,B}(buffer, id) - end -end -function Base.take!(pull::LocalPuller{T,B}) where {T,B} - if pull.buffer === nothing - pull.buffer = - error("Return NullBuffer") - end - value = take!(pull.buffer) -end -function initialize_input_stream!(stream::Stream{T,B}, id::UInt, send_proc::Processor, recv_proc::Processor, buffer_amount::Integer) where {T,B} - local_buffer = remotecall_fetch(stream.ref.handle.owner, stream.ref.handle, id) do ref, id - local_buffer, remote_buffer = initialize_stream_buffer!(B, T, send_proc, recv_proc, buffer_amount) - ref.buffers[id] = remote_buffer - return local_buffer - end - stream.buffer = local_buffer - return stream -end -=# diff --git a/src/stream-transfer.jl b/src/stream-transfer.jl index 3251abb9..66780876 100644 --- a/src/stream-transfer.jl +++ b/src/stream-transfer.jl @@ -1,32 +1,116 @@ +struct RemoteChannelFetcher + chan::RemoteChannel + RemoteChannelFetcher() = new(RemoteChannel()) +end +const _THEIR_TID = TaskLocalValue{Int}(()->0) +function stream_push_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_push "taking output value: $our_tid -> $their_tid" + value = try + take!(buffer) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_push "pushing output value: $our_tid -> $their_tid" + try + put!(fetcher.chan, value) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_push "channel closed: $our_tid -> $their_tid" + throw(InterruptException()) + end + rethrow(err) + end + @dagdebug our_tid :stream_push "finished pushing output value: $our_tid -> $their_tid" +end +function stream_pull_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_pull "pulling input value: $their_tid -> $our_tid" + value = try + take!(fetcher.chan) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_pull "channel closed: $their_tid -> $our_tid" + throw(InterruptException()) + end + rethrow(err) + end + @dagdebug our_tid :stream_pull "putting input value: $their_tid -> $our_tid" + try + put!(buffer, value) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_pull "finished putting input value: $their_tid -> $our_tid" +end + +#= TODO: Remove me +# This is a bad implementation because it wants to sleep on the remote side to +# wait for values, but this isn't semantically valid when done with MemPool.access_ref struct RemoteFetcher end -# TODO: Switch to RemoteChannel approach -function stream_pull_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_remote}, buffer::Blocal, id::UInt) where {Store_remote, Blocal} +function stream_push_values!(::Type{RemoteFetcher}, T, our_store::StreamStore, their_stream::Stream, buffer) + sleep(1) +end +function stream_pull_values!(::Type{RemoteFetcher}, T, our_store::StreamStore, their_stream::Stream, buffer) + id = our_store.uid thunk_id = STREAM_THUNK_ID[] @dagdebug thunk_id :stream "fetching values" - free_space = length(buffer.buffer) - length(buffer) + free_space = capacity(buffer) - length(buffer) if free_space == 0 + @dagdebug thunk_id :stream "waiting for drain of full input buffer" yield() task_may_cancel!() + wait_for_nonfull_input(our_store, their_stream.uid) return end values = T[] while isempty(values) - values = MemPool.access_ref(store_ref.handle, id, T, Store_remote, thunk_id, free_space) do store, id, T, Store_remote, thunk_id, free_space - @dagdebug thunk_id :stream "trying to fetch values at $(myid())" - store::Store_remote - in_store = store + values, closed = MemPool.access_ref(their_stream.store_ref.handle, id, T, thunk_id, free_space) do their_store, id, T, thunk_id, free_space + @dagdebug thunk_id :stream "trying to fetch values at worker $(myid())" STREAM_THUNK_ID[] = thunk_id values = T[] - @dagdebug thunk_id :stream "trying to fetch: $(store.output_buffers[id].count) values, free_space: $free_space" - while !isempty(store, id) && length(values) < free_space - value = take!(store, id)::T + @dagdebug thunk_id :stream "trying to fetch with free_space: $free_space" + wait_for_nonempty_output(their_store, id) + if isempty(their_store, id) && !isopen(their_store, id) + @dagdebug thunk_id :stream "remote stream is closed, returning" + return values, true + end + while !isempty(their_store, id) && length(values) < free_space + value = take!(their_store, id)::T @dagdebug thunk_id :stream "fetched $value" push!(values, value) end - return values - end::Vector{T} + return values, false + end::Tuple{Vector{T},Bool} + if closed + throw(InterruptException()) + end # We explicitly yield in the loop to allow other tasks to run. This # matters on single-threaded instances because MemPool.access_ref() @@ -41,6 +125,4 @@ function stream_pull_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_re put!(buffer, value) end end -function stream_push_values!(::Type{RemoteFetcher}, T, store_ref::Store_remote, buffer::Blocal, id::UInt) where {Store_remote, Blocal} - sleep(1) -end +=# diff --git a/src/stream.jl b/src/stream.jl index 493183b9..f3c2cfbc 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -7,6 +7,8 @@ mutable struct StreamStore{T,B} output_buffers::Dict{UInt,B} input_buffer_amount::Int output_buffer_amount::Int + input_fetchers::Dict{UInt,Any} + output_fetchers::Dict{UInt,Any} open::Bool migrating::Bool lock::Threads.Condition @@ -15,6 +17,7 @@ mutable struct StreamStore{T,B} Dict{UInt,Any}(), Dict{UInt,Any}(), Dict{UInt,B}(), Dict{UInt,B}(), input_buffer_amount, output_buffer_amount, + Dict{UInt,Any}(), Dict{UInt,Any}(), true, false, Threads.Condition()) end @@ -48,6 +51,9 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B} end @dagdebug thunk_id :stream "buffer full ($(length(buffer)) values), waiting" wait(store.lock) + if !isfull(buffer) + @dagdebug thunk_id :stream "buffer has space ($(length(buffer)) values), continuing" + end task_may_cancel!() end put!(buffer, value) @@ -85,6 +91,36 @@ function Base.take!(store::StreamStore, id::UInt) return value end end +function wait_for_nonfull_input(store::StreamStore, id::UInt) + @lock store.lock begin + @assert haskey(store.input_streams, id) + @assert haskey(store.input_buffers, id) + buffer = store.input_buffers[id] + while isfull(buffer) && isopen(store) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for space in input buffer" + wait(store.lock) + end + end +end +function wait_for_nonempty_output(store::StreamStore, id::UInt) + @lock store.lock begin + @assert haskey(store.output_streams, id) + + # Wait for the output buffer to be initialized + while !haskey(store.output_buffers, id) && isopen(store, id) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for output buffer to be initialized" + wait(store.lock) + end + isopen(store, id) || return + + # Wait for the output buffer to be nonempty + buffer = store.output_buffers[id] + while isempty(buffer) && isopen(store, id) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for output buffer to be nonempty" + wait(store.lock) + end + end +end function Base.isempty(store::StreamStore, id::UInt) if !haskey(store.output_buffers, id) @@ -105,6 +141,10 @@ taken. """ function Base.isopen(store::StreamStore, id::UInt) @lock store.lock begin + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + return store.open + end if !isempty(store.output_buffers[id]) return true end @@ -127,13 +167,14 @@ function Base.close(store::StreamStore) end # FIXME: Just pass Stream directly, rather than its uid -function add_waiters!(store::StreamStore{T,B}, waiters::Vector{UInt}) where {T,B} +function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Pair{UInt,Any}}) where {T,B} our_uid = store.uid @lock store.lock begin - for output_uid in waiters + for (output_uid, output_fetcher) in waiters store.output_streams[output_uid] = task_to_stream(output_uid) + push!(store.waiters, output_uid) + store.output_fetchers[output_uid] = output_fetcher end - append!(store.waiters, waiters) notify(store.lock) end end @@ -175,7 +216,8 @@ Base.take!(sv::StreamingValue) = take!(sv.buffer) function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::Stream{IT,IB}) where {IT,OT,IB,OB} input_uid = input_stream.uid our_uid = our_store.uid - buffer = @lock our_store.lock begin + local buffer, input_fetcher + @lock our_store.lock begin if haskey(our_store.input_buffers, input_uid) return StreamingValue(our_store.input_buffers[input_uid]) end @@ -183,7 +225,7 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S buffer = initialize_stream_buffer(OB, IT, our_store.input_buffer_amount) # FIXME: Also pass a RemoteChannel to track remote closure our_store.input_buffers[input_uid] = buffer - buffer + input_fetcher = our_store.input_fetchers[input_uid] end thunk_id = STREAM_THUNK_ID[] tls = get_tls() @@ -192,11 +234,14 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S STREAM_THUNK_ID[] = thunk_id try while isopen(our_store) - # FIXME: Make remote fetcher configurable - stream_pull_values!(RemoteFetcher, IT, input_stream.store_ref, buffer, our_uid) + stream_pull_values!(input_fetcher, IT, our_store, input_stream, buffer) end catch err - err isa InterruptException || rethrow(err) + if err isa InterruptException || (err isa InvalidStateException && !isopen(buffer)) + return + else + rethrow(err) + end finally @dagdebug STREAM_THUNK_ID[] :stream "input stream closed" end @@ -204,23 +249,34 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S return StreamingValue(buffer) end initialize_input_stream!(our_store::StreamStore, arg) = arg -function initialize_output_stream!(store::StreamStore{T,B}, output_uid::UInt) where {T,B} - @assert islocked(store.lock) +function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt) where {T,B} + @assert islocked(our_store.lock) @dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid" - buffer = initialize_stream_buffer(B, T, store.output_buffer_amount) - store.output_buffers[output_uid] = buffer - our_uid = store.uid + buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount) + our_store.output_buffers[output_uid] = buffer + our_uid = our_store.uid + output_stream = our_store.output_streams[output_uid] + output_fetcher = our_store.output_fetchers[output_uid] thunk_id = STREAM_THUNK_ID[] tls = get_tls() Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin set_tls!(tls) + STREAM_THUNK_ID[] = thunk_id try - while isopen(store) - # FIXME: Make remote fetcher configurable - stream_push_values!(RemoteFetcher, T, store, buffer, output_uid) + while true + if !isopen(our_store) && isempty(buffer) + # Only exit if the buffer is empty; otherwise, we need to + # continue draining it + break + end + stream_push_values!(output_fetcher, T, our_store, output_stream, buffer) end catch err - err isa InterruptException || rethrow(err) + if err isa InterruptException || (err isa InvalidStateException && !isopen(buffer)) + return + else + rethrow(err) + end finally @dagdebug thunk_id :stream "output stream closed" end @@ -243,7 +299,7 @@ function Base.close(stream::Stream) return end -function add_waiters!(stream::Stream, waiters::Vector{UInt}) +function add_waiters!(stream::Stream, waiters::Vector{Pair{UInt,Any}}) MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters add_waiters!(store::StreamStore, waiters) return @@ -508,8 +564,8 @@ function _run_streamingfunction(tls, cancel_token, sf, args...; kwargs...) end # Ensure downstream tasks also terminate - @dagdebug thunk_id :stream "closed stream" close(sf.stream) + @dagdebug thunk_id :stream "closed stream store" end end end @@ -530,6 +586,7 @@ function stream!(sf::StreamingFunction, uid, # Exit streaming on migration if sf.stream.store.migrating error("FIXME: max_evals should be retained") + @dagdebug STREAM_THUNK_ID[] :stream "returning for migration" return StreamMigrating() end @@ -538,9 +595,15 @@ function stream!(sf::StreamingFunction, uid, stream_kwarg_values = _stream_take_values!(kwarg_values) stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values) + if length(stream_args) > 0 || length(stream_kwarg_values) > 0 + # Notify tasks that input buffers may have space + @lock sf.stream.store.lock notify(sf.stream.store.lock) + end + # Run a single cycle of f - stream_result = f(stream_args...; stream_kwargs...) counter += 1 + @dagdebug STREAM_THUNK_ID[] :stream "executing $f (eval $counter)" + stream_result = f(stream_args...; stream_kwargs...) # Exit streaming on graceful request if stream_result isa FinishStream @@ -548,6 +611,7 @@ function stream!(sf::StreamingFunction, uid, value = something(stream_result.value) put!(sf.stream, value) end + @dagdebug STREAM_THUNK_ID[] :stream "voluntarily returning" return stream_result.result end @@ -555,8 +619,8 @@ function stream!(sf::StreamingFunction, uid, put!(sf.stream, stream_result) # Exit streaming on eval limit - if sf.max_evals >= 0 && counter >= sf.max_evals - @dagdebug STREAM_THUNK_ID[] :stream "max evals reached" + if sf.max_evals > 0 && counter >= sf.max_evals + @dagdebug STREAM_THUNK_ID[] :stream "max evals reached ($counter)" return end end @@ -596,7 +660,7 @@ function task_to_stream(uid::UInt) end function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) - stream_waiter_changes = Dict{UInt,Vector{UInt}}() + stream_waiter_changes = Dict{UInt,Vector{Pair{UInt,Any}}}() for (spec, task) in tasks @assert haskey(self_streams, task.uid) @@ -614,16 +678,18 @@ function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) if other_stream !== nothing # Generate Stream handle for input - # FIXME: input_fetcher = get(spec.options, :stream_input_fetcher, RemoteFetcher) + # FIXME: Be configurable + input_fetcher = RemoteChannelFetcher() other_stream_handle = Stream(other_stream) spec.args[idx] = pos => other_stream_handle our_stream.store.input_streams[arg.uid] = other_stream_handle + our_stream.store.input_fetchers[arg.uid] = input_fetcher # Add this task as a waiter for the associated output Stream changes = get!(stream_waiter_changes, arg.uid) do - UInt[] + Pair{UInt,Any}[] end - push!(changes, task.uid) + push!(changes, task.uid => input_fetcher) end end end diff --git a/test/streaming.jl b/test/streaming.jl index a994b410..b7d3ce32 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -23,6 +23,7 @@ const ACCUMULATOR = Dict{Int,Vector{Real}}() return end @everywhere accumulator(xs...) = accumulator(sum(xs)) +@everywhere accumulator(::Nothing) = accumulator(0) function catch_interrupt(f) try @@ -60,12 +61,13 @@ function test_finishes(f, message::String; ignore_timeout=false, max_evals=10) end return tset end - timed_out = timedwait(()->istaskdone(t), 5) == :timed_out + timed_out = timedwait(()->istaskdone(t), 10) == :timed_out if timed_out if !ignore_timeout @warn "Testing task timed out: $message" end Dagger.cancel!(;halt_sch=true) + @everywhere GC.gc() fetch(Dagger.@spawn 1+1) end tset = fetch(t)::Test.DefaultTestSet @@ -96,7 +98,7 @@ for idx in 1:5 return y end end - fetch(x) + @test_throws_unwrap InterruptException fetch(x) end @test test_finishes("Single task without result") do @@ -164,7 +166,9 @@ for idx in 1:5 Dagger.spawn_streaming() do x = Dagger.@spawn scope=rand(scopes) rand() end - A = Dagger.@spawn accumulator(x) + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + end @test fetch(x) === nothing @test fetch(A) === nothing values = copy(ACCUMULATOR); empty!(ACCUMULATOR) @@ -177,8 +181,10 @@ for idx in 1:5 Dagger.spawn_streaming() do x = Dagger.@spawn scope=rand(scopes) rand() end - A = Dagger.@spawn accumulator(x) - B = Dagger.@spawn accumulator(x) + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + B = Dagger.@spawn accumulator(x) + end @test fetch(x) === nothing @test fetch(A) === nothing @test fetch(B) === nothing @@ -364,7 +370,7 @@ for idx in 1:5 @test test_finishes("max_evals=100"; max_evals=100) do local A Dagger.spawn_streaming() do - A = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator() end @test fetch(A) === nothing values = copy(ACCUMULATOR); empty!(ACCUMULATOR) @@ -374,44 +380,39 @@ for idx in 1:5 end @testset "DropBuffer ($scope_str)" begin - @test test_finishes("x (drop)-> A") do + # TODO: Test that accumulator never gets called + @test !test_finishes("x (drop)-> A"; ignore_timeout=true) do local x, A Dagger.spawn_streaming() do - Dagger.with_options(;stream_buffer_type=>Dagger.DropBuffer) do + Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do x = Dagger.@spawn scope=rand(scopes) rand() end A = Dagger.@spawn scope=rand(scopes) accumulator(x) end - @test fetch(A) === nothing - values = copy(ACCUMULATOR); empty!(ACCUMULATOR) - A_tid = Dagger.task_id(A) - @test !haskey(values, A_tid) + @test fetch(x) === nothing + @test_throws_unwrap InterruptException fetch(A) === nothing end - @test test_finishes("x ->(drop) A") do + @test !test_finishes("x ->(drop) A"; ignore_timeout=true) do local x, A Dagger.spawn_streaming() do x = Dagger.@spawn scope=rand(scopes) rand() - Dagger.with_options(;stream_buffer_type=>Dagger.DropBuffer) do + Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do A = Dagger.@spawn scope=rand(scopes) accumulator(x) end end - @test fetch(A) === nothing - values = copy(ACCUMULATOR); empty!(ACCUMULATOR) - A_tid = Dagger.task_id(A) - @test !haskey(values, A_tid) + @test fetch(x) === nothing + @test_throws_unwrap InterruptException fetch(A) === nothing end - @test test_finishes("x -(drop)> A") do + @test !test_finishes("x -(drop)> A"; ignore_timeout=true) do local x, A Dagger.spawn_streaming() do - Dagger.with_options(;stream_buffer_type=>Dagger.DropBuffer) do + Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do x = Dagger.@spawn scope=rand(scopes) rand() A = Dagger.@spawn scope=rand(scopes) accumulator(x) end end - @test fetch(A) === nothing - values = copy(ACCUMULATOR); empty!(ACCUMULATOR) - A_tid = Dagger.task_id(A) - @test !haskey(values, A_tid) + @test fetch(x) === nothing + @test_throws_unwrap InterruptException fetch(A) === nothing end end