diff --git a/src/simulation.jl b/src/simulation.jl index c0b65288..acddc5c3 100644 --- a/src/simulation.jl +++ b/src/simulation.jl @@ -37,7 +37,7 @@ Base.Expr(c::UnaryCall) = begin operator = c.operator if c.equality == :.= # TODO: Generalize to inplacable functions - if operator == add_inplace_stub(:⋆₁⁻¹) # Since inverse hodge Geo is a solver + if operator == inplace(:⋆₁⁻¹) # Since inverse hodge Geo is a solver Expr(:call, c.operator, c.output, c.input) elseif operator == :.- Expr(c.equality, c.output, Expr(:call, operator, c.input)) @@ -213,7 +213,7 @@ function get_stub(var_name::Symbol) return Symbol(var_str[begin:first(idx) - 1]) end -add_inplace_stub(var_name::Symbol) = add_stub(GENSIM_INPLACE_STUB, var_name) +inplace(var_name::Symbol) = add_stub(GENSIM_INPLACE_STUB, var_name) const ARITHMETIC_OPS = Set([:+, :*, :-, :/, :.+, :.*, :.-, :./, :^, :.^, :.>, :.<, :.≤, :.≥]) @@ -357,18 +357,10 @@ function hook_STC_settvar(state_name::Symbol, tgt_name::Symbol, ::Union{CPUBacke return :(setproperty!(du, $ssymb, $tgt_name)) end -const PROMOTE_ARITHMETIC_MAP = Dict(:(+) => :.+, - :(-) => :.-, - :(*) => :.*, - :(/) => :./, - :(^) => :.^, - :(=) => :.=, - :.+ => :.+, - :.- => :.-, - :.* => :.*, - :./ => :./, - :.^ => :.^, - :.= => :.=) +const PROMOTABLE_OPS = [:(+), :(-), :(*), :(/), :(^), :(=)] +function promote_arith_op(sym::Symbol) + sym ∈ PROMOTABLE_OPS ? Symbol(".$sym") : sym +end """ compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool) @@ -380,143 +372,73 @@ in-place methods, `dimension` is the dimension of the problem (usually 1 or 2), which is set to `true` by default and determines if intermediate results can be preallocated.. """ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool) - # Get the Vars of the inputs (probably state Vars). - visited_Var = falses(nparts(d, :Var)) - - # TODO: Pass in state indices instead of names - input_numbers = reduce(vcat, incident(d, inputs, :name)) - - visited_Var[input_numbers] .= true - visited_Var[incident(d, :Literal, :type)] .= true - - # TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ - visited_1 = falses(nparts(d, :Op1)) - visited_2 = falses(nparts(d, :Op2)) - visited_Σ = falses(nparts(d, :Σ)) - - # FIXME: this is a quadratic implementation of topological_sort inlined in here. - op_order = AbstractCall[] - - for _ in 1:(nparts(d, :Op1) + nparts(d,:Op2) + nparts(d, :Σ)) - for op in parts(d, :Op1) - s = d[op, :src] - if !visited_1[op] && visited_Var[s] - # skip the derivative edges - operator = d[op, :op1] - t = d[op, :tgt] - visited_1[op] = true - if operator == DerivOp - continue - end - - equality = :(=) - - sname = d[s, :name] - tname = d[t, :name] - - # TODO: Check to see if this is a DEC operator - if preallocate && is_form(d, t) - if operator in optimizable_dec_operators - equality = PROMOTE_ARITHMETIC_MAP[equality] - operator = add_stub(GENSIM_INPLACE_STUB, operator) - push!(alloc_vectors, AllocVecCall(tname, d[t, :type], dimension, stateeltype, code_target)) - - elseif operator == :(-) || operator == :.- - equality = PROMOTE_ARITHMETIC_MAP[equality] - operator = PROMOTE_ARITHMETIC_MAP[operator] - push!(alloc_vectors, AllocVecCall(tname, d[t, :type], dimension, stateeltype, code_target)) - end - end - - visited_Var[t] = true - c = UnaryCall(operator, equality, sname, tname) - push!(op_order, c) - end + + AVC(name, type) = + push!(alloc_vectors, AllocVecCall(name, type, dimension, stateeltype, code_target)) + + ops = AbstractCall[] + + function visit_op1(op) + s, t = only(op.dom), only(op.cod) + + op.name ∈ [:∂ₜ, :dt] && return + + UC(name, eq) = + push!(ops, UnaryCall(name, eq, d[s, :name], d[t, :name])) + + if preallocate && is_form(d, t) && op.name ∈ optimizable_dec_operators + AVC(d[t, :name], d[t, :type]) + UC(inplace(op.name), :.=) + else + UC(op.name, :(=)) end + end - for op in parts(d, :Op2) - arg1 = d[op, :proj1] - arg2 = d[op, :proj2] - if !visited_2[op] && visited_Var[arg1] && visited_Var[arg2] - r = d[op, :res] - a1name = d[arg1, :name] - a2name = d[arg2, :name] - rname = d[r, :name] - - operator = d[op, :op2] - equality = :(=) - - # TODO: Check to make sure that this logic never breaks - if preallocate && is_form(d, r) - if operator == :(+) || operator == :(-) || operator == :.+ || operator == :.- - operator = PROMOTE_ARITHMETIC_MAP[operator] - equality = PROMOTE_ARITHMETIC_MAP[equality] - push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target)) - - # TODO: Do we want to support the ability of a user to use the backslash operator? - elseif operator == :(*) || operator == :(/) || operator == :.* || operator == :./ - # ! WARNING: This part may break if we add more compiler types that have different - # ! operations for basic and broadcast modes, e.g. matrix multiplication vs broadcast - if !is_infer(d, arg1) && !is_infer(d, arg2) - operator = PROMOTE_ARITHMETIC_MAP[operator] - equality = PROMOTE_ARITHMETIC_MAP[equality] - push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target)) - end - elseif operator in optimizable_dec_operators - operator = add_stub(GENSIM_INPLACE_STUB, operator) - equality = PROMOTE_ARITHMETIC_MAP[equality] - push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target)) - end - end - - # TODO: Clean this in another PR (with a @match maybe). - if operator == :(*) - operator = PROMOTE_ARITHMETIC_MAP[operator] - end - if operator == :(-) - operator = PROMOTE_ARITHMETIC_MAP[operator] - end - if operator == :(/) - operator = PROMOTE_ARITHMETIC_MAP[operator] - end - if operator == :(^) - operator = PROMOTE_ARITHMETIC_MAP[operator] - end - - visited_2[op] = true - visited_Var[r] = true - c = BinaryCall(operator, equality, a1name, a2name, rname) - push!(op_order, c) - end + function visit_op2(op) + arg1, arg2, r = op.dom..., op.cod + + BC(name, eq) = + push!(ops, BinaryCall(name, eq, d[arg1, :name], d[arg2, :name], d[r, :name])) + + if !(preallocate && is_form(d,r)) + BC(promote_arith_op(op.name), :(=)) + elseif op.name ∈ optimizable_dec_operators + AVC(d[r, :name], d[r, :type]) + BC(inplace(op.name), :.=) + else + AVC(d[r, :name], d[r, :type]) + BC(promote_arith_op(op.name), :.=) end + end - for op in parts(d, :Σ) - args = subpart(d, incident(d, op, :summation), :summand) - if !visited_Σ[op] && all(visited_Var[args]) - r = d[op, :sum] - argnames = d[args, :name] - rname = d[r, :name] - - # operator = :(+) - operator = :.+ - equality = :(=) - - # If result is a known form, broadcast addition - if preallocate && is_form(d, r) - operator = PROMOTE_ARITHMETIC_MAP[operator] - equality = PROMOTE_ARITHMETIC_MAP[equality] - push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target)) - end - - visited_Σ[op] = true - visited_Var[r] = true - c = SummationCall(equality, argnames, rname) - push!(op_order, c) - end + function visit_Σ(op) + argnames = d[op.dom, :name] + r = d[op.index, :sum] + rname = d[r, :name] + + SC(eq) = push!(ops, SummationCall(eq, argnames, rname)) + + if !(preallocate && is_form(d, r)) + SC(:(=)) + else + AVC(rname, d[r, :type]) + SC(:.=) + end + end + + for op in topological_sort_edges(d) + if op.name == :+ + visit_Σ(op) + elseif length(op.dom) == 1 + visit_op1(op) + elseif length(op.dom) == 2 + visit_op2(op) + else + error("Unknown operation type") end end - eq_exprs = Expr.(op_order) + Expr.(ops) end """ @@ -624,7 +546,7 @@ function link_contract_operators(d::SummationDecapode, con_dec_operators::Set{Sy for op1_id in parts(d, :Op1) op1_name = d[op1_id, :op1] if isa(op1_name, AbstractArray) - computation = reverse!(map(x -> add_inplace_stub(x), op1_name)) + computation = reverse!(map(x -> inplace(x), op1_name)) compute_key = join(computation, " * ") computation_name = get(compute_to_name, compute_key, :Error) @@ -636,7 +558,7 @@ function link_contract_operators(d::SummationDecapode, con_dec_operators::Set{Sy expr_line = hook_LCO_inplace(computation_name, computation, stateeltype, code_target) push!(contract_defs.args, expr_line) - expr_line = Expr(Symbol("="), computation_name, Expr(Symbol("->"), :x, Expr(:call, :*, add_inplace_stub(computation_name), :x))) + expr_line = Expr(Symbol("="), computation_name, Expr(Symbol("->"), :x, Expr(:call, :*, inplace(computation_name), :x))) push!(contract_defs.args, expr_line) curr_id += 1 @@ -651,7 +573,7 @@ end # TODO: Allow user to overload these hooks with user-defined code_target function hook_LCO_inplace(computation_name::Symbol, computation::Vector{Symbol}, float_type::DataType, ::CPUBackend) - return :($(add_inplace_stub(computation_name)) = $(Expr(:call, :*, computation...))) + return :($(inplace(computation_name)) = $(Expr(:call, :*, computation...))) end function generate_parentheses_multiply(list) @@ -663,7 +585,7 @@ function generate_parentheses_multiply(list) end function hook_LCO_inplace(computation_name::Symbol, computation::Vector{Symbol}, float_type::DataType, ::CUDABackend) - return :($(add_inplace_stub(computation_name)) = $(generate_parentheses_multiply(computation))) + return :($(inplace(computation_name)) = $(generate_parentheses_multiply(computation))) end struct UnsupportedDimensionException <: Exception diff --git a/test/simulation_core.jl b/test/simulation_core.jl index 6aa2a9d4..69823f3d 100644 --- a/test/simulation_core.jl +++ b/test/simulation_core.jl @@ -13,7 +13,7 @@ const DOT_EQUALS =:.= # Calling Code Tests # ###################### -import Decapodes: UnaryCall, add_inplace_stub +import Decapodes: UnaryCall, inplace @testset "Test UnaryCall" begin # Test equality, basic operator @@ -34,7 +34,7 @@ import Decapodes: UnaryCall, add_inplace_stub @test Expr(UnaryCall(:.-, DOT_EQUALS, :x, :y)) == :(y .= .-x) # Test broadcast equality, inplace non-matrix method - let inplace_op = add_inplace_stub(:⋆₁⁻¹) + let inplace_op = inplace(:⋆₁⁻¹) @test Expr(UnaryCall(inplace_op, DOT_EQUALS, :x, :y)) == :($inplace_op(y, x)) end end @@ -46,7 +46,7 @@ import Decapodes: BinaryCall @test Expr(BinaryCall(:F, EQUALS, :x, :y, :z)) == :(z = F(x, y)) # Test broadcast equality, inplace operator - let inplace_operator = add_inplace_stub(:F) + let inplace_operator = inplace(:F) @test Expr(BinaryCall(inplace_operator, DOT_EQUALS, :x, :y, :z)) == :($inplace_operator(z, x, y)) end end