Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add closure parsing to at-spawn #423

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
julia_version = "1.7.3"
manifest_format = "2.0"

[[deps.Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "195c5505521008abea5aee4f96930717958eac6f"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.4.0"

[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
version = "0.18.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -21,6 +22,7 @@ TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
Adapt = "1, 2, 3"
ContextVariablesX = "0.1"
DataStructures = "0.18"
MacroTools = "0.5"
Expand Down
2 changes: 2 additions & 0 deletions src/Dagger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ using UUIDs

import ContextVariablesX

import Adapt
using Requires
using MacroTools
using TimespanLogging

include("lib/util.jl")
include("utils/dagdebug.jl")
include("utils/find-thunk.jl")

# Distributed data
include("options.jl")
Expand Down
123 changes: 96 additions & 27 deletions src/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ generated thunks.
macro par(exs...)
opts = exs[1:end-1]
ex = exs[end]
_par(ex; lazy=true, opts=opts)
generate_spawn(ex; lazy=true, opts=opts)
end

"""
Expand All @@ -346,7 +346,7 @@ See the docs for `@par` for more information and usage examples.
macro spawn(exs...)
opts = exs[1:end-1]
ex = exs[end]
_par(ex; lazy=false, opts=opts)
generate_spawn(ex; lazy=false, opts=opts)
end

struct ExpandedBroadcast{F} end
Expand All @@ -360,40 +360,109 @@ function replace_broadcast(fn::Symbol)
return fn
end

function _par(ex::Expr; lazy=true, recur=true, opts=())
if ex.head == :call && recur
f = replace_broadcast(ex.args[1])
if length(ex.args) >= 2 && Meta.isexpr(ex.args[2], :parameters)
args = ex.args[3:end]
kwargs = ex.args[2]
else
args = ex.args[2:end]
kwargs = Expr(:parameters)
function generate_spawn(ex::Expr; lazy=true, mode=nothing, opts=())
if mode === nothing
parse_idx = nothing
for (idx, opt) in enumerate(opts)
if Meta.isexpr(opt, :(=)) && opt.args[1] == :parse
if parse_idx !== nothing
throw(ArgumentError("`parse` can only be specified once"))
end
if !(opt.args[2] isa QuoteNode)
throw(ArgumentError("`parse` option value must be a constant Symbol"))
end
mode = opt.args[2].value
if !(mode in (:closure, :call)) # TODO: :recurse
throw(ArgumentError("Invalid parse mode: $(repr(mode))"))
end
parse_idx = idx
end
end
opts = esc.(opts)
args_ex = _par.(args; lazy=lazy, recur=false)
kwargs_ex = _par.(kwargs.args; lazy=lazy, recur=false)
if lazy
return :(Dagger.delayed($(esc(f)), $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...)))
if parse_idx !== nothing
opts = (opts[1:(parse_idx-1)]..., opts[(parse_idx+1):end]...)
end
end
if mode === nothing
# Automatically pick a mode
if Meta.isexpr(ex, :call)
mode = :call
else
sync_var = esc(Base.sync_varname)
@gensym result
return quote
let args = ($(args_ex...),)
$result = $spawn($(esc(f)), $Options(;$(opts...)), args...; $(kwargs_ex...))
if $(Expr(:islocal, sync_var))
put!($sync_var, schedule(Task(()->wait($result))))
mode = :closure
end
end
if mode == :call
if !Meta.isexpr(ex, :call)
throw(ArgumentError("When `parse=:call`, expression must be a function call"))
end
f = esc(replace_broadcast(ex.args[1]))
has_kw(ex) = length(ex.args) >= 2 &&
(Meta.isexpr(ex.args[2], :parameters) ||
any(iex->Meta.isexpr(iex, :kw), ex.args))
if has_kw(ex)
if Meta.isexpr(ex.args[2], :parameters)
args = ex.args[3:end]
kwargs = ex.args[2].args
else
kwargs = Expr[]
for argidx in length(ex.args):-1:2
arg = ex.args[argidx]
if Meta.isexpr(arg, :kw)
pushfirst!(kwargs, arg)
deleteat!(ex.args, argidx)
end
$result
end
args = ex.args[2:end]
end
else
args = ex.args[2:end]
kwargs = Expr[]
end
args = map(esc, args)
kwargs = map(esc, kwargs)
elseif mode == :closure
f = :(()->$(esc(ex)))
args = []
kwargs = Expr[]
#= TODO: Recurse through AST
elseif mode == :recur
if Meta.isexpr(ex, :(=))
return Expr(:(=), ex.args[1], generate_spawn(ex.args[2]; lazy, mode, opts))
elseif Meta.isexpr(ex, :block) ||
Meta.isexpr(ex, :tuple)
return Expr(ex.head, map(arg->generate_spawn(arg; lazy, mode, opts), ex.args)...)
elseif Meta.isexpr(ex, :if)
cond = ex.args[1]
cond = Expr(:call, :fetch, generate_spawn(cond; lazy, mode, opts))
return Expr(:if, cond, map(arg->generate_spawn(arg; lazy, mode, opts), ex.args[2:end])...)
elseif Meta.isexpr(ex, :call)
# FIXME: Handle recursive calls
return generate_spawn(ex; lazy, mode=:call, opts)
else
return ex
end
=#
end
opts = map(esc, opts)
if lazy
return quote
$delayed($(esc(f)), $Options(;$(opts...)))($(args...); $(kwargs...))
end
else
return Expr(ex.head, _par.(ex.args, lazy=lazy, recur=recur, opts=opts)...)
sync_var = esc(Base.sync_varname)
@gensym result
return quote
let
$result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...))
if $(Expr(:islocal, sync_var))
put!($sync_var, $schedule($Task(()->wait($result))))
end
$result
end
end
end
end
_par(ex::Symbol; kwargs...) = esc(ex)
_par(ex; kwargs...) = ex
generate_spawn(ex::Symbol; kwargs...) = ex
generate_spawn(ex; kwargs...) = ex

persist!(t::Thunk) = (t.persist=true; t)
cache_result!(t::Thunk) = (t.cache=true; t)
Expand Down
6 changes: 6 additions & 0 deletions src/utils/find-thunks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
struct RecordAdaptor
tasks::Set{Any}
end
struct FetchAdaptor end
Adapt.adapt_storage(ra::RecordAdaptor, t::Thunk) = (push!(ra.tasks, t); t)
Adapt.adapt_storage(::FetchAdaptor, t::Thunk) = fetch(t)
Loading