Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic data prefetch #199

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions src/sch/Sch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ Fields
- ready::Vector{Thunk} - The list of `Thunk`s that are ready to execute
- cache::Dict{Thunk, Any} - Maps from a finished `Thunk` to its cached result, often a DRef
- running::Set{Thunk} - The set of currently-running `Thunk`s
- thunk_dict::Dict{Int, Any} - Maps from thunk IDs to a `Thunk`
- running_worker::Dict{Thunk,Int} - Maps from a running `Thunk` to the ID of its worker
- thunk_dict::Dict{Int, Thunk} - Maps from thunk IDs to a `Thunk`
- node_order::Any - Function that returns the order of a thunk
- worker_pressure::Dict{Int,Int} - Cache of worker pressure
- worker_capacity::Dict{Int,Int} - Maps from worker ID to capacity
Expand All @@ -44,7 +45,8 @@ struct ComputeState
ready::Vector{Thunk}
cache::Dict{Thunk, Any}
running::Set{Thunk}
thunk_dict::Dict{Int, Any}
running_worker::Dict{Thunk, Int}
thunk_dict::Dict{Int, Thunk}
node_order::Any
worker_pressure::Dict{Int,Int}
worker_capacity::Dict{Int,Int}
Expand Down Expand Up @@ -345,8 +347,34 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
pop_and_fire!(ctx, state, p)
end
end
schedule_memory!(ctx, state)
end
end

"""
    schedule_memory!(ctx, state)

Ask workers to prefetch inputs for thunks that depend on currently-running thunks.

For each thunk in `state.running`, look up its worker in `state.running_worker` and
walk its entries in `state.dependents`. For every dependent not yet handled during this
call, each input that is not itself a task is wrapped with `Dagger.tochunk` and sent to
that worker through a fire-and-forget `remote_do(prefetch, ...)`, tagged with the
dependent's `id` and the input's position.

`ctx` is currently unused. Returns `nothing`.
"""
function schedule_memory!(ctx, state)
    # TODO: Store this in state
    # Rebuilt on every call, so de-duplication only holds within a single invocation.
    already_sent = Set{Thunk}()
    for running_thunk in state.running
        pid = state.running_worker[running_thunk]
        for dependent in state.dependents[running_thunk]
            if dependent in already_sent
                continue
            end
            push!(already_sent, dependent)
            for (idx, input) in enumerate(dependent.inputs)
                # Inputs that are tasks are skipped; only non-task inputs are sent.
                istask(input) && continue
                remote_do(prefetch, pid, dependent.id, idx, Dagger.tochunk(input))
            end
        end
    end
    return nothing
end

# Module-level cache of prefetched thunk inputs: thunk ID => (input index => value).
# `prefetch` fills it with the result of `move(OSProc(), chunk)`; the compute path looks
# entries up by `thunk_id` and `abs(id)` before falling back to `move`, and deletes the
# whole per-thunk entry once that thunk has been computed.
# NOTE(review): no locking is visible around this Dict — confirm that all accesses
# happen from tasks that cannot race.
const CACHED_INPUTS = Dict{Int,Dict{Int,Any}}()

"""
    prefetch(tid, idx, chunk)

Store the result of `move(OSProc(), chunk)` in `CACHED_INPUTS[tid][idx]`, unless an
entry for `idx` is already present for thunk `tid`.

The per-thunk dictionary is created on first use. Always returns `nothing`.
"""
function prefetch(tid, idx, chunk)
    slots = get!(CACHED_INPUTS, tid) do
        Dict{Int,Any}()
    end
    # Only move the data when this input slot has not been filled yet.
    haskey(slots, idx) || (slots[idx] = move(OSProc(), chunk))
    return nothing
end

function pop_and_fire!(ctx, state, proc)
task = pop_with_affinity!(ctx, state.ready, proc)
if task !== nothing
Expand Down Expand Up @@ -421,6 +449,7 @@ end

function fire_task!(ctx, thunk, proc, state)
push!(state.running, thunk)
state.running_worker[thunk] = proc.pid
if thunk.cache && thunk.cache_ref !== nothing
# the result might be already cached
data = unrelease(thunk.cache_ref) # ask worker to keep the data around
Expand Down Expand Up @@ -469,6 +498,7 @@ end

function finish_task!(state, node, thunk_failed; free=true)
pop!(state.running, node)
delete!(state.running_worker, node)
if !thunk_failed
push!(state.finished, node)
else
Expand Down Expand Up @@ -526,6 +556,7 @@ function start_state(deps::Dict, node_order, chan)
Vector{Thunk}(undef, 0),
Dict{Thunk, Any}(),
Set{Thunk}(),
Dict{Thunk, Int}(),
Dict{Int, Thunk}(),
node_order,
Dict{Int,Int}(),
Expand Down Expand Up @@ -568,7 +599,12 @@ end
fetch.(map(Iterators.zip(data,ids)) do (x, id)
@async begin
@dbg timespan_start(ctx, :move, (thunk_id, id), (f, id))
x = move(to_proc, x)
x = if haskey(CACHED_INPUTS, thunk_id) &&
haskey(CACHED_INPUTS[thunk_id], abs(id))
CACHED_INPUTS[thunk_id][abs(id)]
else
move(to_proc, x)
end
@dbg timespan_end(ctx, :move, (thunk_id, id), (f, id))
return x
end
Expand All @@ -592,6 +628,9 @@ end
bt = catch_backtrace()
RemoteException(myid(), CapturedException(ex, bt))
end
if haskey(CACHED_INPUTS, thunk_id)
delete!(CACHED_INPUTS, thunk_id)
end
@dbg timespan_end(ctx, :compute, thunk_id, (f, to_proc, typeof(res), sizeof(res)))
metadata = (pressure=ACTIVE_TASKS[],)
(result_meta, metadata)
Expand Down