Skip to content

Commit

Permalink
Fix contraction tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 committed Dec 5, 2024
1 parent 7318ca8 commit 23bc80c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 21 deletions.
12 changes: 5 additions & 7 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -609,12 +609,11 @@ function init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol},
end
end

""" link_contract_operators!(d::SummationDecapode, contract_defs, con_dec_operators::Set{Symbol}, stateeltype::DataType, code_target::AbstractGenerationTarget)
""" link_contract_operators!(d::SummationDecapode, contract_defs::Expr, con_dec_operators::Set{Symbol}, stateeltype::DataType, code_target::AbstractGenerationTarget)
Collects arrays of DEC matrices together, replaces the array with a generated function name and computes the contracted multiplication
"""
function link_contract_operators!(d::SummationDecapode, contract_defs, con_dec_operators::Set{Symbol}, stateeltype::DataType, code_target::AbstractGenerationTarget)

function link_contract_operators!(d::SummationDecapode, contract_defs::Expr, con_dec_operators::Set{Symbol}, stateeltype::DataType, code_target::AbstractGenerationTarget)
compute_to_name = Dict()
curr_id = 1

Expand Down Expand Up @@ -642,8 +641,6 @@ function link_contract_operators!(d::SummationDecapode, contract_defs, con_dec_o
d[op1_id, :op1] = computation_name
end
end

contract_defs
end

# TODO: Allow user to overload these hooks with user-defined code_target
Expand Down Expand Up @@ -746,9 +743,10 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension
contracted_dec_operators = Set{Symbol}()
if contract
contract_operators!(gen_d, white_list = optimizable_dec_operators)
cont_defs = link_contract_operators!(gen_d, contract_defs, contracted_dec_operators, stateeltype, code_target)
link_contract_operators!(gen_d, contract_defs, contracted_dec_operators, stateeltype, code_target)
end


union!(optimizable_dec_operators, contracted_dec_operators, extra_dec_operators)

# Compilation of the simulation
Expand All @@ -764,7 +762,7 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension
quote
(mesh, operators, hodge=GeometricHodge()) -> begin
$func_defs
$cont_defs
$contract_defs

Check warning on line 765 in src/simulation.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation.jl#L765

Added line #L765 was not covered by tests
$prologue
$vect_defs
f(__du__, __u__, __p__, __t__) = begin
Expand Down
Binary file added test/.simulation.jl.swp
Binary file not shown.
27 changes: 13 additions & 14 deletions test/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,16 +325,13 @@ end

@testset "Gensim Transformations" begin

function checkForContractionInGensim(d::SummationDecapode)
results = []
block = gensim(d).args[2].args[2].args[5]
for line in 2:length(block.args)
push!(results, block.args[line].args[1])
end

return results
function count_contractions(e::Expr)
block = e.args[2].args[2].args[5]
length(block.args) - 1
end

count_contractions(d::SummationDecapode) = count_contractions(gensim(d))

begin
primal_earth = loadmesh(Icosphere(1))
nploc = argmax(x -> x[3], primal_earth[:point])
Expand Down Expand Up @@ -376,7 +373,9 @@ end
B == ((A))
D == d(d(C))
end
@test 4 == length(checkForContractionInGensim(single_contract))
@test 4 == count_contractions(single_contract)

@test 0 == count_contractions(gensim(single_contract; contract=false))

sim = eval(gensim(single_contract))
f = sim(earth, default_dec_generate)
Expand All @@ -403,7 +402,7 @@ end

D == d(d(C))
end
@test 4 == length(checkForContractionInGensim(single_contract))
@test 4 == count_contractions(contract_with_summation)

sim = eval(gensim(contract_with_summation))
f = sim(earth, default_dec_generate)
Expand All @@ -430,7 +429,7 @@ end

D == d(d(C))
end
@test 4 == length(checkForContractionInGensim(single_contract))
@test 4 == count_contractions(contract_with_op2)

for prealloc in [false, true]
let sim = eval(gensim(contract_with_op2, preallocate = prealloc))
Expand All @@ -456,7 +455,7 @@ end
B == A * A
D == ((B))
end
@test 4 == length(checkForContractionInGensim(single_contract))
@test 2 == count_contractions(later_contraction)

sim = eval(gensim(later_contraction))
f = sim(earth, default_dec_generate)
Expand All @@ -476,7 +475,7 @@ end
D == ∂ₜ(A)
D == d(A)
end
@test 0 == length(checkForContractionInGensim(no_contraction))
@test 0 == count_contractions(no_contraction)

sim = eval(gensim(no_contraction))
f = sim(earth, default_dec_generate)
Expand All @@ -496,7 +495,7 @@ end
D == ∂ₜ(A)
D == d(k(A))
end
@test 0 == length(checkForContractionInGensim(no_unallowed))
@test 0 == count_contractions(no_unallowed)

sim = eval(gensim(no_unallowed))

Expand Down

0 comments on commit 23bc80c

Please sign in to comment.