From f819829e1ffccf31c9d99f2d78d0793785c02fd4 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 5 Dec 2024 16:35:53 -0700 Subject: [PATCH] Document link_contract_operators! and remove counter --- src/simulation.jl | 28 +++++++++++++++------------- test/simulation.jl | 16 ++++++++++++---- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/simulation.jl b/src/simulation.jl index 77e82a44..92d5de92 100644 --- a/src/simulation.jl +++ b/src/simulation.jl @@ -615,30 +615,32 @@ Collects arrays of DEC matrices together, replaces the array with a generated fu """ function link_contract_operators!(d::SummationDecapode, contract_defs::Expr, con_dec_operators::Set{Symbol}, stateeltype::DataType, code_target::AbstractGenerationTarget) compute_to_name = Dict() - curr_id = 1 for op1_id in parts(d, :Op1) op1_name = d[op1_id, :op1] - if isa(op1_name, AbstractArray) + if op1_name isa AbstractArray + # Pre-multiply the matrices. + # e.g. var"GenSim-M_d₀" * var"GenSim-M_⋆₀⁻¹" computation = reverse!(map(x -> add_inplace_stub(x), op1_name)) compute_key = join(computation, " * ") - computation_name = get(compute_to_name, compute_key, :Error) - if computation_name == :Error - computation_name = add_stub(Symbol("GenSim-ConMat"), Symbol(curr_id)) - get!(compute_to_name, compute_key, computation_name) + if compute_key ∉ keys(compute_to_name) + computation_name = add_stub(Symbol("GenSim-ConMat"), Symbol(length(compute_to_name))) + compute_to_name[compute_key] = computation_name push!(con_dec_operators, computation_name) - expr_line = hook_LCO_inplace(computation_name, computation, stateeltype, code_target) - push!(contract_defs.args, expr_line) + # Pre-multiply the matrices. + # e.g. var"GenSim-M_GenSim-ConMat_1" = var"GenSim-M_d₀" * var"GenSim-M_⋆₀⁻¹" + push!(contract_defs.args, + hook_LCO_inplace(computation_name, computation, stateeltype, code_target)) - expr_line = Expr(Symbol("="), computation_name, Expr(Symbol("->"), :x, Expr(:call, :*, add_inplace_stub(computation_name), :x))) - push!(contract_defs.args, expr_line) - - curr_id += 1 + # Define a function which multiplies by the given matrix. + # e.g. var"GenSim-ConMat_1" = (x->var"GenSim-M_GenSim-ConMat_1" * x) + push!(contract_defs.args, + Expr(Symbol("="), computation_name, Expr(Symbol("->"), :x, Expr(:call, :*, add_inplace_stub(computation_name), :x)))) end - d[op1_id, :op1] = computation_name + d[op1_id, :op1] = compute_to_name[compute_key] end end end diff --git a/test/simulation.jl b/test/simulation.jl index aa678c9e..00afbf01 100644 --- a/test/simulation.jl +++ b/test/simulation.jl @@ -364,30 +364,38 @@ end # Testing simple contract operations single_contract = @decapode begin - (A,C)::Form0 - (D)::Form2 + (A,C,E)::Form0 + (D,F)::Form2 B == ∂ₜ(A) D == ∂ₜ(C) B == ⋆(⋆(A)) D == d(d(C)) + F == d(d(E)) end @test 4 == count_contractions(single_contract) @test 0 == count_contractions(gensim(single_contract; contract=false)) + f = gensim(single_contract) + @test f.args[2].args[2].args[5].args[[2,4]] == [ + :(var"GenSim-M_GenSim-ConMat_0" = var"GenSim-M_d₁" * var"GenSim-M_d₀"), + :(var"GenSim-M_GenSim-ConMat_1" = var"GenSim-M_⋆₀⁻¹" * var"GenSim-M_⋆₀")] + sim = eval(gensim(single_contract)) f = sim(earth, default_dec_generate) A = 2 * ones(nv(earth)) C = ones(nv(earth)) - u = ComponentArray(A=A, C=C) - du = ComponentArray(A=zeros(nv(earth)), C=zeros(ntriangles(earth))) + E = ones(nv(earth)) + u = ComponentArray(A=A, C=C, E=E) + du = ComponentArray(A=zeros(nv(earth)), C=zeros(ntriangles(earth)), E=zeros(ntriangles(earth))) constants_and_parameters = () f(du, u, constants_and_parameters, 0) @test du.A ≈ 2 * ones(nv(earth)) @test du.C == zeros(ntriangles(earth)) + @test du.E == zeros(ntriangles(earth)) # Testing contraction interrupted by summation contract_with_summation = @decapode begin