diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl
index 82cedbbe2..88ab68dbd 100644
--- a/src/sch/Sch.jl
+++ b/src/sch/Sch.jl
@@ -25,7 +25,8 @@ Fields
 - ready::Vector{Thunk} - The list of `Thunk`s that are ready to execute
 - cache::Dict{Thunk, Any} - Maps from a finished `Thunk` to it's cached result, often a DRef
 - running::Set{Thunk} - The set of currently-running `Thunk`s
-- thunk_dict::Dict{Int, Any} - Maps from thunk IDs to a `Thunk`
+- running_worker::Dict{Thunk, Int} - Maps from a running `Thunk` to the pid of its worker
+- thunk_dict::Dict{Int, Thunk} - Maps from thunk IDs to a `Thunk`
 - node_order::Any - Function that returns the order of a thunk
 - worker_pressure::Dict{Int,Int} - Cache of worker pressure
 - worker_capacity::Dict{Int,Int} - Maps from worker ID to capacity
@@ -44,7 +45,8 @@ struct ComputeState
     ready::Vector{Thunk}
     cache::Dict{Thunk, Any}
     running::Set{Thunk}
-    thunk_dict::Dict{Int, Any}
+    running_worker::Dict{Thunk, Int}
+    thunk_dict::Dict{Int, Thunk}
     node_order::Any
     worker_pressure::Dict{Int,Int}
     worker_capacity::Dict{Int,Int}
@@ -345,8 +347,49 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
                 pop_and_fire!(ctx, state, p)
             end
         end
+        schedule_memory!(ctx, state)
     end
 end
+
+"""
+    schedule_memory!(ctx, state)
+
+For each running thunk, ask its worker to prefetch the non-thunk inputs of the
+thunk's dependents. A dependent is processed at most once per call.
+"""
+function schedule_memory!(ctx, state)
+    sent = Set{Thunk}() # TODO: Store this in state
+    for up in state.running
+        proc = state.running_worker[up]
+        for down in state.dependents[up]
+            down in sent && continue
+            push!(sent, down)
+            for (idx, input) in enumerate(down.inputs)
+                # thunks have no value to fetch yet; skip them without allocating
+                istask(input) && continue
+                remote_do(prefetch, proc, down.id, idx, Dagger.tochunk(input))
+            end
+        end
+    end
+end
+
+# Prefetched inputs of the local process: thunk ID => (input index => moved value)
+const CACHED_INPUTS = Dict{Int,Dict{Int,Any}}()
+
+"""
+    prefetch(tid, idx, chunk)
+
+Move `chunk` to `OSProc()` and store it in `CACHED_INPUTS[tid][idx]`, unless that
+slot is already filled. Returns `nothing`.
+"""
+function prefetch(tid, idx, chunk)
+    d = get!(()->Dict{Int,Any}(), CACHED_INPUTS, tid)
+    if !haskey(d, idx)
+        d[idx] = move(OSProc(), chunk)
+    end
+    return nothing
+end
+
 function pop_and_fire!(ctx, state, proc)
     task = pop_with_affinity!(ctx, state.ready, proc)
     if task !== nothing
@@ -421,6 +464,7 @@ end
 
 function fire_task!(ctx, thunk, proc, state)
     push!(state.running, thunk)
+    state.running_worker[thunk] = proc.pid
     if thunk.cache && thunk.cache_ref !== nothing
         # the result might be already cached
         data = unrelease(thunk.cache_ref) # ask worker to keep the data around
@@ -469,6 +513,7 @@ end
 
 function finish_task!(state, node, thunk_failed; free=true)
     pop!(state.running, node)
+    delete!(state.running_worker, node)
     if !thunk_failed
         push!(state.finished, node)
     else
@@ -526,6 +571,7 @@ function start_state(deps::Dict, node_order, chan)
         Vector{Thunk}(undef, 0),
         Dict{Thunk, Any}(),
         Set{Thunk}(),
+        Dict{Thunk, Int}(),
         Dict{Int, Thunk}(),
         node_order,
         Dict{Int,Int}(),
@@ -568,7 +614,13 @@ end
         fetch.(map(Iterators.zip(data,ids)) do (x, id)
             @async begin
                 @dbg timespan_start(ctx, :move, (thunk_id, id), (f, id))
-                x = move(to_proc, x)
+                # use the prefetched copy of this input when one was cached
+                cached = get(CACHED_INPUTS, thunk_id, nothing)
+                x = if cached !== nothing && haskey(cached, abs(id))
+                    cached[abs(id)]
+                else
+                    move(to_proc, x)
+                end
                 @dbg timespan_end(ctx, :move, (thunk_id, id), (f, id))
                 return x
             end
@@ -592,6 +644,8 @@ end
         bt = catch_backtrace()
         RemoteException(myid(), CapturedException(ex, bt))
     end
+    # drop this thunk's prefetched inputs (no-op when none were cached)
+    delete!(CACHED_INPUTS, thunk_id)
     @dbg timespan_end(ctx, :compute, thunk_id, (f, to_proc, typeof(res), sizeof(res)))
     metadata = (pressure=ACTIVE_TASKS[],)
     (result_meta, metadata)