Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use traversal abstraction #270

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 70 additions & 148 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Base.Expr(c::UnaryCall) = begin
operator = c.operator
if c.equality == :.=
# TODO: Generalize to inplacable functions
if operator == add_inplace_stub(:⋆₁⁻¹) # Since inverse hodge Geo is a solver
if operator == inplace(:⋆₁⁻¹) # Since inverse hodge Geo is a solver
Expr(:call, c.operator, c.output, c.input)
elseif operator == :.-
Expr(c.equality, c.output, Expr(:call, operator, c.input))
Expand Down Expand Up @@ -213,7 +213,7 @@ function get_stub(var_name::Symbol)
return Symbol(var_str[begin:first(idx) - 1])
end

add_inplace_stub(var_name::Symbol) = add_stub(GENSIM_INPLACE_STUB, var_name)
inplace(var_name::Symbol) = add_stub(GENSIM_INPLACE_STUB, var_name)

const ARITHMETIC_OPS = Set([:+, :*, :-, :/, :.+, :.*, :.-, :./, :^, :.^, :.>, :.<, :.≤, :.≥])

Expand Down Expand Up @@ -357,18 +357,10 @@ function hook_STC_settvar(state_name::Symbol, tgt_name::Symbol, ::Union{CPUBacke
return :(setproperty!(du, $ssymb, $tgt_name))
end

const PROMOTE_ARITHMETIC_MAP = Dict(:(+) => :.+,
:(-) => :.-,
:(*) => :.*,
:(/) => :./,
:(^) => :.^,
:(=) => :.=,
:.+ => :.+,
:.- => :.-,
:.* => :.*,
:./ => :./,
:.^ => :.^,
:.= => :.=)
const PROMOTABLE_OPS = [:(+), :(-), :(*), :(/), :(^), :(=)]
function promote_arith_op(sym::Symbol)
sym ∈ PROMOTABLE_OPS ? Symbol(".$sym") : sym
end

"""
compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
Expand All @@ -380,143 +372,73 @@ in-place methods, `dimension` is the dimension of the problem (usually 1 or 2),
which is set to `true` by default and determines if intermediate results can be preallocated..
"""
function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
# Get the Vars of the inputs (probably state Vars).
visited_Var = falses(nparts(d, :Var))

# TODO: Pass in state indices instead of names
input_numbers = reduce(vcat, incident(d, inputs, :name))

visited_Var[input_numbers] .= true
visited_Var[incident(d, :Literal, :type)] .= true

# TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ
visited_1 = falses(nparts(d, :Op1))
visited_2 = falses(nparts(d, :Op2))
visited_Σ = falses(nparts(d, :Σ))

# FIXME: this is a quadratic implementation of topological_sort inlined in here.
op_order = AbstractCall[]

for _ in 1:(nparts(d, :Op1) + nparts(d,:Op2) + nparts(d, :Σ))
for op in parts(d, :Op1)
s = d[op, :src]
if !visited_1[op] && visited_Var[s]
# skip the derivative edges
operator = d[op, :op1]
t = d[op, :tgt]
visited_1[op] = true
if operator == DerivOp
continue
end

equality = :(=)

sname = d[s, :name]
tname = d[t, :name]

# TODO: Check to see if this is a DEC operator
if preallocate && is_form(d, t)
if operator in optimizable_dec_operators
equality = PROMOTE_ARITHMETIC_MAP[equality]
operator = add_stub(GENSIM_INPLACE_STUB, operator)
push!(alloc_vectors, AllocVecCall(tname, d[t, :type], dimension, stateeltype, code_target))

elseif operator == :(-) || operator == :.-
equality = PROMOTE_ARITHMETIC_MAP[equality]
operator = PROMOTE_ARITHMETIC_MAP[operator]
push!(alloc_vectors, AllocVecCall(tname, d[t, :type], dimension, stateeltype, code_target))
end
end

visited_Var[t] = true
c = UnaryCall(operator, equality, sname, tname)
push!(op_order, c)
end

AVC(name, type) =
push!(alloc_vectors, AllocVecCall(name, type, dimension, stateeltype, code_target))

ops = AbstractCall[]

function visit_op1(op)
s, t = only(op.dom), only(op.cod)

op.name ∈ [:∂ₜ, :dt] && return

UC(name, eq) =
push!(ops, UnaryCall(name, eq, d[s, :name], d[t, :name]))

if preallocate && is_form(d, t) && op.name ∈ optimizable_dec_operators
AVC(d[t, :name], d[t, :type])
UC(inplace(op.name), :.=)
else
UC(op.name, :(=))
end
end

for op in parts(d, :Op2)
arg1 = d[op, :proj1]
arg2 = d[op, :proj2]
if !visited_2[op] && visited_Var[arg1] && visited_Var[arg2]
r = d[op, :res]
a1name = d[arg1, :name]
a2name = d[arg2, :name]
rname = d[r, :name]

operator = d[op, :op2]
equality = :(=)

# TODO: Check to make sure that this logic never breaks
if preallocate && is_form(d, r)
if operator == :(+) || operator == :(-) || operator == :.+ || operator == :.-
operator = PROMOTE_ARITHMETIC_MAP[operator]
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))

# TODO: Do we want to support the ability of a user to use the backslash operator?
elseif operator == :(*) || operator == :(/) || operator == :.* || operator == :./
# ! WARNING: This part may break if we add more compiler types that have different
# ! operations for basic and broadcast modes, e.g. matrix multiplication vs broadcast
if !is_infer(d, arg1) && !is_infer(d, arg2)
operator = PROMOTE_ARITHMETIC_MAP[operator]
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
end
elseif operator in optimizable_dec_operators
operator = add_stub(GENSIM_INPLACE_STUB, operator)
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
end
end

# TODO: Clean this in another PR (with a @match maybe).
if operator == :(*)
operator = PROMOTE_ARITHMETIC_MAP[operator]
end
if operator == :(-)
operator = PROMOTE_ARITHMETIC_MAP[operator]
end
if operator == :(/)
operator = PROMOTE_ARITHMETIC_MAP[operator]
end
if operator == :(^)
operator = PROMOTE_ARITHMETIC_MAP[operator]
end

visited_2[op] = true
visited_Var[r] = true
c = BinaryCall(operator, equality, a1name, a2name, rname)
push!(op_order, c)
end
function visit_op2(op)
arg1, arg2, r = op.dom..., op.cod

BC(name, eq) =
push!(ops, BinaryCall(name, eq, d[arg1, :name], d[arg2, :name], d[r, :name]))

if !(preallocate && is_form(d,r))
BC(promote_arith_op(op.name), :(=))
elseif op.name ∈ optimizable_dec_operators
AVC(d[r, :name], d[r, :type])
BC(inplace(op.name), :.=)
else
AVC(d[r, :name], d[r, :type])
BC(promote_arith_op(op.name), :.=)
end
end

for op in parts(d, :Σ)
args = subpart(d, incident(d, op, :summation), :summand)
if !visited_Σ[op] && all(visited_Var[args])
r = d[op, :sum]
argnames = d[args, :name]
rname = d[r, :name]

# operator = :(+)
operator = :.+
equality = :(=)

# If result is a known form, broadcast addition
if preallocate && is_form(d, r)
operator = PROMOTE_ARITHMETIC_MAP[operator]
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
end

visited_Σ[op] = true
visited_Var[r] = true
c = SummationCall(equality, argnames, rname)
push!(op_order, c)
end
function visit_Σ(op)
argnames = d[op.dom, :name]
r = d[op.index, :sum]
rname = d[r, :name]

SC(eq) = push!(ops, SummationCall(eq, argnames, rname))

if !(preallocate && is_form(d, r))
SC(:(=))
else
AVC(rname, d[r, :type])
SC(:.=)
end
end

for op in topological_sort_edges(d)
if op.name == :+
visit_Σ(op)
elseif length(op.dom) == 1
visit_op1(op)
elseif length(op.dom) == 2
visit_op2(op)
else
error("Unknown operation type")
end
end

eq_exprs = Expr.(op_order)
Expr.(ops)
end

"""
Expand Down Expand Up @@ -624,7 +546,7 @@ function link_contract_operators(d::SummationDecapode, con_dec_operators::Set{Sy
for op1_id in parts(d, :Op1)
op1_name = d[op1_id, :op1]
if isa(op1_name, AbstractArray)
computation = reverse!(map(x -> add_inplace_stub(x), op1_name))
computation = reverse!(map(x -> inplace(x), op1_name))
compute_key = join(computation, " * ")

computation_name = get(compute_to_name, compute_key, :Error)
Expand All @@ -636,7 +558,7 @@ function link_contract_operators(d::SummationDecapode, con_dec_operators::Set{Sy
expr_line = hook_LCO_inplace(computation_name, computation, stateeltype, code_target)
push!(contract_defs.args, expr_line)

expr_line = Expr(Symbol("="), computation_name, Expr(Symbol("->"), :x, Expr(:call, :*, add_inplace_stub(computation_name), :x)))
expr_line = Expr(Symbol("="), computation_name, Expr(Symbol("->"), :x, Expr(:call, :*, inplace(computation_name), :x)))
push!(contract_defs.args, expr_line)

curr_id += 1
Expand All @@ -651,7 +573,7 @@ end

# TODO: Allow user to overload these hooks with user-defined code_target
function hook_LCO_inplace(computation_name::Symbol, computation::Vector{Symbol}, float_type::DataType, ::CPUBackend)
return :($(add_inplace_stub(computation_name)) = $(Expr(:call, :*, computation...)))
return :($(inplace(computation_name)) = $(Expr(:call, :*, computation...)))
end

function generate_parentheses_multiply(list)
Expand All @@ -663,7 +585,7 @@ function generate_parentheses_multiply(list)
end

function hook_LCO_inplace(computation_name::Symbol, computation::Vector{Symbol}, float_type::DataType, ::CUDABackend)
return :($(add_inplace_stub(computation_name)) = $(generate_parentheses_multiply(computation)))
return :($(inplace(computation_name)) = $(generate_parentheses_multiply(computation)))
end

struct UnsupportedDimensionException <: Exception
Expand Down
6 changes: 3 additions & 3 deletions test/simulation_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const DOT_EQUALS =:.=
# Calling Code Tests #
######################

import Decapodes: UnaryCall, add_inplace_stub
import Decapodes: UnaryCall, inplace

@testset "Test UnaryCall" begin
# Test equality, basic operator
Expand All @@ -34,7 +34,7 @@ import Decapodes: UnaryCall, add_inplace_stub
@test Expr(UnaryCall(:.-, DOT_EQUALS, :x, :y)) == :(y .= .-x)

# Test broadcast equality, inplace non-matrix method
let inplace_op = add_inplace_stub(:⋆₁⁻¹)
let inplace_op = inplace(:⋆₁⁻¹)
@test Expr(UnaryCall(inplace_op, DOT_EQUALS, :x, :y)) == :($inplace_op(y, x))
end
end
Expand All @@ -46,7 +46,7 @@ import Decapodes: BinaryCall
@test Expr(BinaryCall(:F, EQUALS, :x, :y, :z)) == :(z = F(x, y))

# Test broadcast equality, inplace operator
let inplace_operator = add_inplace_stub(:F)
let inplace_operator = inplace(:F)
@test Expr(BinaryCall(inplace_operator, DOT_EQUALS, :x, :y, :z)) == :($inplace_operator(z, x, y))
end
end
Expand Down
Loading