From c22b50d9cadd18d5da21154a5bea58f133695867 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 18 Oct 2022 20:11:02 -0500 Subject: [PATCH 1/2] at-spawn: Add parsing options Add closure mode Use closure mode for `begin ... end` expressions --- src/thunk.jl | 123 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 96 insertions(+), 27 deletions(-) diff --git a/src/thunk.jl b/src/thunk.jl index 92b6de589..249ef4545 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -332,7 +332,7 @@ generated thunks. macro par(exs...) opts = exs[1:end-1] ex = exs[end] - _par(ex; lazy=true, opts=opts) + generate_spawn(ex; lazy=true, opts=opts) end """ @@ -346,7 +346,7 @@ See the docs for `@par` for more information and usage examples. macro spawn(exs...) opts = exs[1:end-1] ex = exs[end] - _par(ex; lazy=false, opts=opts) + generate_spawn(ex; lazy=false, opts=opts) end struct ExpandedBroadcast{F} end @@ -360,40 +360,109 @@ function replace_broadcast(fn::Symbol) return fn end -function _par(ex::Expr; lazy=true, recur=true, opts=()) - if ex.head == :call && recur - f = replace_broadcast(ex.args[1]) - if length(ex.args) >= 2 && Meta.isexpr(ex.args[2], :parameters) - args = ex.args[3:end] - kwargs = ex.args[2] - else - args = ex.args[2:end] - kwargs = Expr(:parameters) +function generate_spawn(ex::Expr; lazy=true, mode=nothing, opts=()) + if mode === nothing + parse_idx = nothing + for (idx, opt) in enumerate(opts) + if Meta.isexpr(opt, :(=)) && opt.args[1] == :parse + if parse_idx !== nothing + throw(ArgumentError("`parse` can only be specified once")) + end + if !(opt.args[2] isa QuoteNode) + throw(ArgumentError("`parse` option value must be a constant Symbol")) + end + mode = opt.args[2].value + if !(mode in (:closure, :call)) # TODO: :recurse + throw(ArgumentError("Invalid parse mode: $(repr(mode))")) + end + parse_idx = idx + end end - opts = esc.(opts) - args_ex = _par.(args; lazy=lazy, recur=false) - kwargs_ex = _par.(kwargs.args; lazy=lazy, recur=false) - if lazy - return :(Dagger.delayed($(esc(f)), $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...))) + if parse_idx !== nothing + opts = (opts[1:(parse_idx-1)]..., opts[(parse_idx+1):end]...) + end + end + if mode === nothing + # Automatically pick a mode + if Meta.isexpr(ex, :call) + mode = :call else - sync_var = esc(Base.sync_varname) - @gensym result - return quote - let args = ($(args_ex...),) - $result = $spawn($(esc(f)), $Options(;$(opts...)), args...; $(kwargs_ex...)) - if $(Expr(:islocal, sync_var)) - put!($sync_var, schedule(Task(()->wait($result)))) + mode = :closure + end + end + if mode == :call + if !Meta.isexpr(ex, :call) + throw(ArgumentError("When `parse=:call`, expression must be a function call")) + end + f = esc(replace_broadcast(ex.args[1])) + has_kw(ex) = length(ex.args) >= 2 && + (Meta.isexpr(ex.args[2], :parameters) || + any(iex->Meta.isexpr(iex, :kw), ex.args)) + if has_kw(ex) + if Meta.isexpr(ex.args[2], :parameters) + args = ex.args[3:end] + kwargs = ex.args[2].args + else + kwargs = Expr[] + for argidx in length(ex.args):-1:2 + arg = ex.args[argidx] + if Meta.isexpr(arg, :kw) + pushfirst!(kwargs, arg) + deleteat!(ex.args, argidx) end - $result end + args = ex.args[2:end] end + else + args = ex.args[2:end] + kwargs = Expr[] + end + args = map(esc, args) + kwargs = map(esc, kwargs) + elseif mode == :closure + f = :(()->$(esc(ex))) + args = [] + kwargs = Expr[] + #= TODO: Recurse through AST + elseif mode == :recur + if Meta.isexpr(ex, :(=)) + return Expr(:(=), ex.args[1], generate_spawn(ex.args[2]; lazy, mode, opts)) + elseif Meta.isexpr(ex, :block) || + Meta.isexpr(ex, :tuple) + return Expr(ex.head, map(arg->generate_spawn(arg; lazy, mode, opts), ex.args)...) + elseif Meta.isexpr(ex, :if) + cond = ex.args[1] + cond = Expr(:call, :fetch, generate_spawn(cond; lazy, mode, opts)) + return Expr(:if, cond, map(arg->generate_spawn(arg; lazy, mode, opts), ex.args[2:end])...) + elseif Meta.isexpr(ex, :call) + # FIXME: Handle recursive calls + return generate_spawn(ex; lazy, mode=:call, opts) + else + return ex + end + =# + end + opts = map(esc, opts) + if lazy + return quote + $delayed($(esc(f)), $Options(;$(opts...)))($(args...); $(kwargs...)) end else - return Expr(ex.head, _par.(ex.args, lazy=lazy, recur=recur, opts=opts)...) + sync_var = esc(Base.sync_varname) + @gensym result + return quote + let + $result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + if $(Expr(:islocal, sync_var)) + put!($sync_var, $schedule($Task(()->wait($result)))) + end + $result + end + end end end -_par(ex::Symbol; kwargs...) = esc(ex) -_par(ex; kwargs...) = ex +generate_spawn(ex::Symbol; kwargs...) = ex +generate_spawn(ex; kwargs...) = ex persist!(t::Thunk) = (t.persist=true; t) cache_result!(t::Thunk) = (t.cache=true; t) From dad0bbcece2f202f31c36b731b27d6a8e9575ab9 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 15 Aug 2023 11:03:55 -0500 Subject: [PATCH 2/2] TEMP: Add Thunk locator --- Manifest.toml | 6 ++++++ Project.toml | 2 ++ src/Dagger.jl | 2 ++ src/utils/find-thunks.jl | 6 ++++++ 4 files changed, 16 insertions(+) create mode 100644 src/utils/find-thunks.jl diff --git a/Manifest.toml b/Manifest.toml index af053bf85..8ed1050d8 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -3,6 +3,12 @@ julia_version = "1.7.3" manifest_format = "2.0" +[[deps.Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "195c5505521008abea5aee4f96930717958eac6f" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.4.0" + [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" diff --git a/Project.toml b/Project.toml index b7efa5b7e..b2220eaf1 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54" version = "0.18.1" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -21,6 +22,7 @@ TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] +Adapt = "1, 2, 3" ContextVariablesX = "0.1" DataStructures = "0.18" MacroTools = "0.5" diff --git a/src/Dagger.jl b/src/Dagger.jl index 6bf708b6d..e1f40c1ad 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -14,12 +14,14 @@ using UUIDs import ContextVariablesX +import Adapt using Requires using MacroTools using TimespanLogging include("lib/util.jl") include("utils/dagdebug.jl") +include("utils/find-thunk.jl") # Distributed data include("options.jl") diff --git a/src/utils/find-thunks.jl b/src/utils/find-thunks.jl new file mode 100644 index 000000000..4894c2826 --- /dev/null +++ b/src/utils/find-thunks.jl @@ -0,0 +1,6 @@ +struct RecordAdaptor + tasks::Set{Any} +end +struct FetchAdaptor end +Adapt.adapt_storage(ra::RecordAdaptor, t::Thunk) = (push!(ra.tasks, t); t) +Adapt.adapt_storage(::FetchAdaptor, t::Thunk) = fetch(t)