Skip to content

Commit

Permalink
Encapsulate, remove dead code, simplify control flow
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 committed Dec 7, 2024
1 parent dd44b57 commit f48ac23
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 120 deletions.
Binary file added src/.simulation.jl.swp
Binary file not shown.
176 changes: 57 additions & 119 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,33 +132,6 @@ function hook_AVC_caching(c::AllocVecCall, resolved_form::Symbol, ::CUDABackend)
:($(c.name) = CuVector{$(c.T)}(undef, nparts(mesh, $(QuoteNode(resolved_form)))))
end

"""
compile_var(alloc_vectors::Vector{AllocVecCall})
This creates the vector allocations that will be used by the simulation body for in-place operations.
"""
function compile_var(alloc_vectors::Vector{AllocVecCall})
return quote $(Expr.(alloc_vectors)...) end
end

#= function get_form_number(d::SummationDecapode, var_id::Int)
type = d[var_id, :type]
if type == :Form0
return 0
elseif type == :Form1
return 1
elseif type == :Form2
return 2
end
return -1
end
# ! WARNING: This may not work if names are not unique, use above instead
function get_form_number(d::SummationDecapode, var_name::Symbol)
var_id = first(incident(d, var_name, :name))
return get_form_number(d, var_id)
end =#

# TODO: This should be edited when we replace types as symbols with types as Julia types
function is_form(d::SummationDecapode, var_id::Int)
type = d[var_id, :type]
Expand Down Expand Up @@ -224,49 +197,31 @@ end

Base.showerror(io::IO, e::InvalidCodeTargetException) = print(io, "Provided code target $(e.code_target) is not yet supported in simulations")

"""
compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec_operators::Vector{Symbol}, code_target::AbstractGenerationTarget)
# TODO: This function should be handled with dispatch.
""" compile_env(d::SummationDecapode, basic_dec_ops::Vector{Symbol}, contracted_ops::Vector{Symbol}, code_target::AbstractGenerationTarget)
This creates the symbol to function linking for the simulation output. Those run through the `default_dec` backend
expect both an in-place and an out-of-place variant in that order. User defined operations only support out-of-place.
Emit code to define functions given operator Symbols.
Default operations return a tuple of an in-place and an out-of-place function. User-defined operations return an out-of-place function.
"""
function compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec_operators::Vector{Symbol}, code_target::AbstractGenerationTarget)
defined_ops = Set(con_dec_operators)
function compile_env(d::SummationDecapode, basic_dec_ops::Set{Symbol}, contracted_ops::Vector{Symbol}, code_target::AbstractGenerationTarget)
default_generation = @match code_target begin
::CPUBackend => :default_dec_matrix_generate
::CUDABackend => :default_dec_cu_matrix_generate
_ => throw(InvalidCodeTargetException(code_target))
end

defs = quote end

for op in dec_matrices
op in defined_ops && continue

for op in setdiff(basic_dec_ops d[:op1] d[:op2], contracted_ops [DerivOp] ARITHMETIC_OPS)
quote_op = QuoteNode(op)
mat_op = add_stub(GENSIM_INPLACE_STUB, op)

# TODO: Add support for user-defined code targets
default_generation = @match code_target begin
::CPUBackend => :default_dec_matrix_generate
::CUDABackend => :default_dec_cu_matrix_generate
_ => throw(InvalidCodeTargetException(code_target))
end

def = :(($mat_op, $op) = $(default_generation)(mesh, $quote_op, hodge))
def = op in basic_dec_ops ?
:(($(add_inplace_stub(op)), $op) = $(default_generation)(mesh, $quote_op, hodge)) :
:($op = operators(mesh, $quote_op))
push!(defs.args, def)

push!(defined_ops, op)
end

# Add in user-defined operations
for op in vcat(d[:op1], d[:op2])
if op == DerivOp || op in defined_ops || op in ARITHMETIC_OPS
continue
end
ops = QuoteNode(op)
def = :($op = operators(mesh, $ops))
push!(defs.args, def)

push!(defined_ops, op)
end

return defs
defs
end

struct AmbiguousNameException <: Exception
Expand Down Expand Up @@ -372,15 +327,16 @@ const PROMOTE_ARITHMETIC_MAP = Dict(:(+) => :.+,
:.= => :.=)

"""
compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
compile(d::SummationDecapode, inputs::Vector{Symbol}, inplace_dec_ops::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
Function that compiles the computation body. `d` is the input Decapode, `inputs` is a vector of state variables and literals,
`alloc_vec` should be empty when passed in, `optimizable_dec_operators` is a collection of all DEC operator symbols that can use special
`inplace_dec_ops` is a collection of all DEC operator symbols that can use special
in-place methods, `dimension` is the dimension of the problem (usually 1 or 2), `stateeltype` is the type of the state elements
(usually Float32 or Float64), `code_target` determines what architecture the code is compiled for (either CPU or CUDA), and `preallocate`
which is set to `true` by default and determines if intermediate results can be preallocated..
"""
function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
function compile(d::SummationDecapode, inputs::Vector{Symbol}, inplace_dec_ops::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
alloc_vectors = Vector{AllocVecCall}()
# Get the Vars of the inputs (probably state Vars).
visited_Var = falses(nparts(d, :Var))

Expand Down Expand Up @@ -417,7 +373,7 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve

# TODO: Check to see if this is a DEC operator
if preallocate && is_form(d, t)
if operator in optimizable_dec_operators
if operator in inplace_dec_ops
equality = PROMOTE_ARITHMETIC_MAP[equality]
operator = add_stub(GENSIM_INPLACE_STUB, operator)
push!(alloc_vectors, AllocVecCall(tname, d[t, :type], dimension, stateeltype, code_target))
Expand Down Expand Up @@ -463,7 +419,7 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
end
elseif operator in optimizable_dec_operators
elseif operator in inplace_dec_ops
operator = add_stub(GENSIM_INPLACE_STUB, operator)
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
Expand Down Expand Up @@ -517,7 +473,7 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve
end
end

eq_exprs = Expr.(op_order)
Expr.(op_order), alloc_vectors
end

"""
Expand Down Expand Up @@ -581,32 +537,24 @@ function infer_overload_compiler!(d::SummationDecapode, dimension::Int)
end
end

"""
init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, optimizable_dec_operators::Set{Symbol})
Collects all DEC operators that are concrete matrices.
"""
function init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, optimizable_dec_operators::Set{Symbol})
for op_name in vcat(d[:op1], d[:op2])
op_name in optimizable_dec_operators && push!(dec_matrices, op_name)
end
end

""" link_contracted_operators!(d::SummationDecapode, contract_defs::Expr, con_dec_operators::Set{Symbol}, code_target::AbstractGenerationTarget)
""" link_contracted_operators!(d::SummationDecapode, code_target::AbstractGenerationTarget)
Emit code to pre-multiply unique sequences of matrix operations, and rename corresponding operations.
"""
function link_contracted_operators!(d::SummationDecapode, contract_defs::Expr, con_dec_operators::Vector{Symbol}, code_target::AbstractGenerationTarget)
function link_contracted_operators!(d::SummationDecapode, code_target::AbstractGenerationTarget)
contracted_defs = quote end
contracted_ops = Symbol[]
chain_idxs = findall(x -> x isa AbstractArray, d[:op1])

for (i, chain) in enumerate(unique(d[chain_idxs, :op1]))
LHS = add_stub(Symbol("GenSim-ConMat"), Symbol(i-1))
RHS = reverse!(add_inplace_stub.(chain))

push!(con_dec_operators, LHS)
push!(contract_defs.args, mat_def_expr(LHS, RHS, code_target), mat_mul_func_expr(LHS))
push!(contracted_ops, LHS)
push!(contracted_defs.args, mat_def_expr(LHS, RHS, code_target), mat_mul_func_expr(LHS))
d[findall(==(chain), d[:op1]), :op1] = LHS
end
contracted_defs, contracted_ops
end

# Given the name of a matrix, return an Expr that multiplies by that matrix.
Expand Down Expand Up @@ -664,69 +612,59 @@ to operator mappings to return a simulator that can be used to solve the represe
`multigrid`: Enables multigrid methods during code generation. If `true`, then the function produced by `gensim` will expect a `PrimalGeometricMapSeries`. (Defaults to `false`)
"""
function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension::Int=2, stateeltype::DataType = Float64, code_target::AbstractGenerationTarget = CPUTarget(), preallocate::Bool = true, contract::Bool = true, multigrid::Bool = false)

(dimension == 1 || dimension == 2) ||
throw(UnsupportedDimensionException(dimension))

(stateeltype == Float32 || stateeltype == Float64) ||
throw(UnsupportedStateeltypeException(stateeltype))

# Explicit copy for safety
gen_d = deepcopy(user_d)
d = deepcopy(user_d)

recognize_types(gen_d)
recognize_types(d)

# Makes copy
gen_d = expand_operators(gen_d)

dec_matrices = Vector{Symbol}()
alloc_vectors = Vector{AllocVecCall}()
d = expand_operators(d)

vars = get_vars_code(gen_d, input_vars, stateeltype, code_target)
tars = set_tanvars_code(gen_d, code_target)
vars = get_vars_code(d, input_vars, stateeltype, code_target)
tars = set_tanvars_code(d, code_target)

infer_overload_compiler!(gen_d, dimension)
convert_cs_ps_to_infer!(gen_d)
infer_overload_compiler!(gen_d, dimension)
infer_overload_compiler!(d, dimension)
convert_cs_ps_to_infer!(d)
infer_overload_compiler!(d, dimension)

# TODO: This should probably be followed by an expand_operators.
replace_names!(gen_d, Pair{Symbol, Any}[], Pair{Symbol, Symbol}[(:₀₀ => :.*)])
open_operators!(gen_d, dimension = dimension)
infer_overload_compiler!(gen_d, dimension)
# XXX: expand_operators should be called if any replacement is a chain of operations.
replace_names!(d, Pair{Symbol, Any}[], Pair{Symbol, Symbol}[(:₀₀ => :.*)])
open_operators!(d, dimension = dimension)
infer_overload_compiler!(d, dimension)

# Generate necessary fundamental DEC operators.
optimizable_dec_operators = Set([:₀, :₁, :₂, :₀⁻¹, :₂⁻¹,
:d₀, :d₁, :dual_d₀, :d̃₀, :dual_d₁, :d̃₁,
:avg₀₁])
extra_dec_operators = Set([:₁⁻¹, :₀₁, :₁₀, :₁₁, :₀₂, :₂₀])
# Determine basic DEC operators to generate.
matrix_dec_ops = Set([:₀, :₁, :₂, :₀⁻¹, :₂⁻¹, :d₀, :d₁, :dual_d₀, :d̃₀, :dual_d₁, :d̃₁, :avg₀₁])
non_matrix_dec_ops = Set([:₁⁻¹, :₀₁, :₁₀, :₁₁, :₀₂, :₂₀])
dec_ops = matrix_dec_ops non_matrix_dec_ops

init_dec_matrices!(gen_d, dec_matrices, union(optimizable_dec_operators, extra_dec_operators))
basic_dec_ops = Set{Symbol}(dec_ops (d[:op1] d[:op2]))

# Pre-multiply sequences of matrices.
contract_defs = quote end
contracted_dec_operators = Symbol[]
if contract
contract_operators!(gen_d, white_list = optimizable_dec_operators)
link_contracted_operators!(gen_d, contract_defs, contracted_dec_operators, code_target)
end
contract && contract_operators!(d, white_list = matrix_dec_ops)
contracted_defs, contracted_ops = link_contracted_operators!(d, code_target)

union!(optimizable_dec_operators, contracted_dec_operators, extra_dec_operators)
inplace_dec_ops = dec_ops contracted_ops

# Compilation of the simulation
equations = compile(gen_d, input_vars, alloc_vectors, optimizable_dec_operators, dimension, stateeltype, code_target, preallocate)
equations, alloc_vectors = compile(d, input_vars, inplace_dec_ops, dimension, stateeltype, code_target, preallocate)
data = post_process_vector_allocs(alloc_vectors, code_target)

func_defs = compile_env(gen_d, dec_matrices, contracted_dec_operators, code_target)
vect_defs = compile_var(alloc_vectors)
func_defs = compile_env(d, basic_dec_ops, contracted_ops, code_target)
vect_defs = quote $(Expr.(alloc_vectors)...) end

prologue = quote end
multigrid && push!(prologue.args, :(mesh = finest_mesh(mesh)))
multigrid_defs = quote end
multigrid && push!(multigrid_defs.args, :(mesh = finest_mesh(mesh)))

quote
(mesh, operators, hodge=GeometricHodge()) -> begin
$func_defs
$contract_defs
$prologue
$contracted_defs
$multigrid_defs
$vect_defs
f(__du__, __u__, __p__, __t__) = begin
$vars
Expand Down
2 changes: 1 addition & 1 deletion test/simulation_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ import Decapodes: compile_env, InvalidCodeTargetException
# Test that error throws on unknown code target
let d = @decapode begin end
struct BadTarget <: AbstractGenerationTarget end
@test_throws InvalidCodeTargetException compile_env(d, [:test], Symbol[], BadTarget())
@test_throws InvalidCodeTargetException compile_env(d, Set{Symbol}([:test]), Symbol[], BadTarget())
end

end
Expand Down

0 comments on commit f48ac23

Please sign in to comment.