Skip to content

Commit

Permalink
Add keyword argument for contract operators
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 committed Dec 5, 2024
1 parent cc035d8 commit 7318ca8
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -609,14 +609,11 @@ function init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol},
end
end

"""
link_contract_operators(d::SummationDecapode, con_dec_operators::Set{Symbol}, stateeltype::DataType, code_target::AbstractGenerationTarget)
""" link_contract_operators!(d::SummationDecapode, contract_defs, con_dec_operators::Set{Symbol}, stateeltype::DataType, code_target::AbstractGenerationTarget)
Collects arrays of DEC matrices together, replaces the array with a generated function name and computes the contracted multiplication
"""
function link_contract_operators(d::SummationDecapode, con_dec_operators::Set{Symbol}, stateeltype::DataType, code_target::AbstractGenerationTarget)

contract_defs = quote end
function link_contract_operators!(d::SummationDecapode, contract_defs, con_dec_operators::Set{Symbol}, stateeltype::DataType, code_target::AbstractGenerationTarget)

compute_to_name = Dict()
curr_id = 1
Expand Down Expand Up @@ -700,9 +697,11 @@ to operator mappings to return a simulator that can be used to solve the represe
`preallocate`: Enables(`true`)/disables(`false`) pre-allocated caches for intermediate computations. Some functions, such as those that determine Jacobian sparsity patterns, or perform auto-differentiation, may require this to be disabled. (Defaults to `true`)
`contract`: Enables(`true`)/disables(`false`) pre-computation of matrix-matrix multiplications for chains of such operators. This feature can interfere with certain auto-differentiation methods, in which case this can be disabled. (Defaults to `true`)
`multigrid`: Enables multigrid methods during code generation. If `true`, then the function produced by `gensim` will expect a `PrimalGeometricMapSeries`. (Defaults to `false`)
"""
function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension::Int=2, stateeltype::DataType = Float64, code_target::AbstractGenerationTarget = CPUTarget(), preallocate::Bool = true, multigrid::Bool = false)
function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension::Int=2, stateeltype::DataType = Float64, code_target::AbstractGenerationTarget = CPUTarget(), preallocate::Bool = true, contract::Bool = true, multigrid::Bool = false)

(dimension == 1 || dimension == 2) ||
throw(UnsupportedDimensionException(dimension))
Expand Down Expand Up @@ -743,9 +742,12 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension
init_dec_matrices!(gen_d, dec_matrices, union(optimizable_dec_operators, extra_dec_operators))

# This contracts matrices together into a single matrix
contract_defs = quote end
contracted_dec_operators = Set{Symbol}()
contract_operators!(gen_d, white_list = optimizable_dec_operators)
cont_defs = link_contract_operators(gen_d, contracted_dec_operators, stateeltype, code_target)
if contract
contract_operators!(gen_d, white_list = optimizable_dec_operators)
cont_defs = link_contract_operators!(gen_d, contract_defs, contracted_dec_operators, stateeltype, code_target)
end

union!(optimizable_dec_operators, contracted_dec_operators, extra_dec_operators)

Expand Down

0 comments on commit 7318ca8

Please sign in to comment.