Skip to content

Commit

Permalink
Remove Dict-centric code
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 committed Dec 6, 2024
1 parent 7dc544c commit dd44b57
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 81 deletions.
121 changes: 41 additions & 80 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ abstract type AbstractGenerationTarget end

abstract type CPUBackend <: AbstractGenerationTarget end
abstract type CUDABackend <: AbstractGenerationTarget end
# TODO: Test that AbstractGenerationTargets are user-extendable.

struct CPUTarget <: CPUBackend end
struct CUDATarget <: CUDABackend end
Expand Down Expand Up @@ -224,13 +225,13 @@ end
Base.showerror(io::IO, e::InvalidCodeTargetException) = print(io, "Provided code target $(e.code_target) is not yet supported in simulations")

"""
compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec_operators::Set{Symbol}, code_target::AbstractGenerationTarget)
compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec_operators::Vector{Symbol}, code_target::AbstractGenerationTarget)
This creates the symbol to function linking for the simulation output. Those run through the `default_dec` backend
expect both an in-place and an out-of-place variant in that order. User defined operations only support out-of-place.
"""
function compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec_operators::Set{Symbol}, code_target::AbstractGenerationTarget)
defined_ops = deepcopy(con_dec_operators)
function compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec_operators::Vector{Symbol}, code_target::AbstractGenerationTarget)
defined_ops = Set(con_dec_operators)

defs = quote end

Expand Down Expand Up @@ -555,29 +556,13 @@ function hook_PPVA_data_handle!(cache_exprs::Vector{Expr}, alloc_vec::AllocVecCa
end

"""
resolve_types_compiler!(d::SummationDecapode)
convert_cs_ps_to_infer!(d::SummationDecapode)
Converts `Constant` and `Parameter` types to `infer` since this is essentially what they are
to the compiler.
Convert `Constant` and `Parameter` types to `infer`.
"""
function resolve_types_compiler!(d::SummationDecapode)
d[:type] = map(d[:type]) do x
if x == :Constant || x == :Parameter
return :infer
end
return x
end
end

"""
replace_names_compiler!(d::SummationDecapode)
This makes easy function name conversions in the Decapode
"""
function replace_names_compiler!(d::SummationDecapode)
dec_op1 = Pair{Symbol, Any}[]
dec_op2 = Pair{Symbol, Symbol}[(:₀₀ => :.*)]
replace_names!(d, dec_op1, dec_op2)
function convert_cs_ps_to_infer!(d::SummationDecapode)
cs_ps = incident(d, [:Constant, :Parameter], :type)
d[vcat(cs_ps...), :type] = :infer
end

# TODO: This should be extended to accept user rules
Expand All @@ -603,64 +588,42 @@ Collects all DEC operators that are concrete matrices.
"""
function init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, optimizable_dec_operators::Set{Symbol})
for op_name in vcat(d[:op1], d[:op2])
if op_name in optimizable_dec_operators
push!(dec_matrices, op_name)
end
op_name in optimizable_dec_operators && push!(dec_matrices, op_name)
end
end

""" link_contract_operators!(d::SummationDecapode, contract_defs::Expr, con_dec_operators::Set{Symbol}, stateeltype::DataType, code_target::AbstractGenerationTarget)
""" link_contracted_operators!(d::SummationDecapode, contract_defs::Expr, con_dec_operators::Set{Symbol}, code_target::AbstractGenerationTarget)
Collects arrays of DEC matrices together, replaces the array with a generated function name and computes the contracted multiplication
Emit code to pre-multiply unique sequences of matrix operations, and rename corresponding operations.
"""
function link_contract_operators!(d::SummationDecapode, contract_defs::Expr, con_dec_operators::Set{Symbol}, stateeltype::DataType, code_target::AbstractGenerationTarget)
compute_to_name = Dict()

for op1_id in parts(d, :Op1)
op1_name = d[op1_id, :op1]
if op1_name isa AbstractArray
# Pre-multiply the matrices.
# e.g. var"GenSim-M_d₀" * var"GenSim-M_⋆₀⁻¹"
computation = reverse!(map(x -> add_inplace_stub(x), op1_name))
compute_key = join(computation, " * ")

if compute_key keys(compute_to_name)
computation_name = add_stub(Symbol("GenSim-ConMat"), Symbol(length(compute_to_name)))
compute_to_name[compute_key] = computation_name
push!(con_dec_operators, computation_name)

# Pre-multiply the matrices.
# e.g. var"GenSim-M_GenSim-ConMat_1" = var"GenSim-M_d₀" * var"GenSim-M_⋆₀⁻¹"
push!(contract_defs.args,
hook_LCO_inplace(computation_name, computation, stateeltype, code_target))

# Define a function which multiplies by the given matrix.
# e.g. var"GenSim-ConMat_1" = (x->var"GenSim-M_GenSim-ConMat_1" * x)
push!(contract_defs.args,
Expr(Symbol("="), computation_name, Expr(Symbol("->"), :x, Expr(:call, :*, add_inplace_stub(computation_name), :x))))
end
function link_contracted_operators!(d::SummationDecapode, contract_defs::Expr, con_dec_operators::Vector{Symbol}, code_target::AbstractGenerationTarget)
chain_idxs = findall(x -> x isa AbstractArray, d[:op1])

d[op1_id, :op1] = compute_to_name[compute_key]
end
for (i, chain) in enumerate(unique(d[chain_idxs, :op1]))
LHS = add_stub(Symbol("GenSim-ConMat"), Symbol(i-1))
RHS = reverse!(add_inplace_stub.(chain))

push!(con_dec_operators, LHS)
push!(contract_defs.args, mat_def_expr(LHS, RHS, code_target), mat_mul_func_expr(LHS))
d[findall(==(chain), d[:op1]), :op1] = LHS
end
end

# TODO: Allow user to overload these hooks with user-defined code_target
function hook_LCO_inplace(computation_name::Symbol, computation::Vector{Symbol}, float_type::DataType, ::CPUBackend)
return :($(add_inplace_stub(computation_name)) = $(Expr(:call, :*, computation...)))
end
# Given the name of a matrix, return an Expr that multiplies by that matrix.
mat_mul_func_expr(mat_name) =
:($mat_name = x -> $(add_inplace_stub(mat_name)) * x)

function generate_parentheses_multiply(list)
if length(list) == 1
return list[1]
else
return Expr(:call, :*, generate_parentheses_multiply(list[1:end-1]), list[end])
end
end
# Given the name and factors of a matrix, return an Expr that defines that matrix.
mat_def_expr(computation_name::Symbol, factors::Vector{Symbol}, ::CPUBackend) =
:($(add_inplace_stub(computation_name)) = *($(factors...)))

function hook_LCO_inplace(computation_name::Symbol, computation::Vector{Symbol}, float_type::DataType, ::CUDABackend)
return :($(add_inplace_stub(computation_name)) = $(generate_parentheses_multiply(computation)))
end
nested_mul(factors) =

Check warning on line 620 in src/simulation.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation.jl#L620

Added line #L620 was not covered by tests
length(factors) == 1 ?
list[begin] :
Expr(:call, :*, nested_mul(factors[begin:end-1]), factors[end])

mat_def_expr(computation_name::Symbol, factors::Vector{Symbol}, ::CUDABackend) =

Check warning on line 625 in src/simulation.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation.jl#L625

Added line #L625 was not covered by tests
:($(add_inplace_stub(computation_name)) = $(nested_mul(factors)))

struct UnsupportedDimensionException <: Exception
dim::Int
Expand Down Expand Up @@ -722,33 +685,31 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension
vars = get_vars_code(gen_d, input_vars, stateeltype, code_target)
tars = set_tanvars_code(gen_d, code_target)

# We need to run this after we grab the constants and parameters out
infer_overload_compiler!(gen_d, dimension)
resolve_types_compiler!(gen_d)
convert_cs_ps_to_infer!(gen_d)
infer_overload_compiler!(gen_d, dimension)

# This should probably be followed by an expand_operators
replace_names_compiler!(gen_d)
# TODO: This should probably be followed by an expand_operators.
replace_names!(gen_d, Pair{Symbol, Any}[], Pair{Symbol, Symbol}[(:₀₀ => :.*)])
open_operators!(gen_d, dimension = dimension)
infer_overload_compiler!(gen_d, dimension)

# This will generate all of the fundemental DEC operators present
# Generate necessary fundamental DEC operators.
optimizable_dec_operators = Set([:₀, :₁, :₂, :₀⁻¹, :₂⁻¹,
:d₀, :d₁, :dual_d₀, :d̃₀, :dual_d₁, :d̃₁,
:avg₀₁])
extra_dec_operators = Set([:₁⁻¹, :₀₁, :₁₀, :₁₁, :₀₂, :₂₀])

init_dec_matrices!(gen_d, dec_matrices, union(optimizable_dec_operators, extra_dec_operators))

# This contracts matrices together into a single matrix
# Pre-multiply sequences of matrices.
contract_defs = quote end
contracted_dec_operators = Set{Symbol}()
contracted_dec_operators = Symbol[]
if contract
contract_operators!(gen_d, white_list = optimizable_dec_operators)
link_contract_operators!(gen_d, contract_defs, contracted_dec_operators, stateeltype, code_target)
link_contracted_operators!(gen_d, contract_defs, contracted_dec_operators, code_target)
end


union!(optimizable_dec_operators, contracted_dec_operators, extra_dec_operators)

# Compilation of the simulation
Expand Down
2 changes: 1 addition & 1 deletion test/simulation_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ import Decapodes: compile_env, InvalidCodeTargetException
# Test that error throws on unknown code target
let d = @decapode begin end
struct BadTarget <: AbstractGenerationTarget end
@test_throws InvalidCodeTargetException compile_env(d, [:test], Set{Symbol}(), BadTarget())
@test_throws InvalidCodeTargetException compile_env(d, [:test], Symbol[], BadTarget())
end

end
Expand Down

0 comments on commit dd44b57

Please sign in to comment.