From 55d2972c175ad704b4991d351e2ee2e0117680f4 Mon Sep 17 00:00:00 2001 From: Luke Morris <70283489+lukem12345@users.noreply.github.com> Date: Fri, 19 Jul 2024 22:10:02 -0400 Subject: [PATCH] Check black list when contracting operators (#26) --- src/acset.jl | 62 +++++++++++++++++++++++++++++------------------- test/language.jl | 15 ++++++++++++ 2 files changed, 53 insertions(+), 24 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index 2dc7860..9665afd 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -285,20 +285,26 @@ function expand_operators(d::SummationDecapode) return e end -""" function contract_operators(d::SummationDecapode; allowable_ops::Set{Symbol} = Set{Symbol}()) +""" function contract_operators(d::SummationDecapode; white_list::Set{Symbol} = Set{Symbol}(), black_list::Set{Symbol} = Set{Symbol}()) Find chains of Op1s in the given Decapode, and replace them with a single Op1 with a vector of function names. After this process, -all Vars that are not a part of any computation are removed. +all Vars that are not a part of any computation are removed. If a +white list is provided, only chain those operators. If a black list +is provided, do not chain those operators. """ -function contract_operators(d::SummationDecapode; allowable_ops::Set{Symbol} = Set{Symbol}()) +function contract_operators(d::SummationDecapode; + white_list::Set{Symbol} = Set{Symbol}(), + black_list::Set{Symbol} = Set{Symbol}()) e = expand_operators(d) - contract_operators!(e, allowable_ops = allowable_ops) + contract_operators!(e, white_list=white_list, black_list=black_list) #return e end -function contract_operators!(d::SummationDecapode; allowable_ops::Set{Symbol} = Set{Symbol}()) - chains = find_chains(d, allowable_ops = allowable_ops) +function contract_operators!(d::SummationDecapode; + white_list::Set{Symbol} = Set{Symbol}(), + black_list::Set{Symbol} = Set{Symbol}()) + chains = find_chains(d, white_list=white_list, black_list=black_list) filter!(x -> length(x) != 1, chains) for chain in chains add_part!(d, :Op1, src=d[:src][first(chain)], tgt=d[:tgt][last(chain)], op1=Vector{Symbol}(d[:op1][chain])) @@ -327,28 +333,32 @@ function remove_neighborless_vars!(d::SummationDecapode) d end -""" - function find_chains(d::SummationDecapode; allowable_ops::Set{Symbol} = Set{Symbol}()) +""" function find_chains(d::SummationDecapode; white_list::Set{Symbol} = Set{Symbol}(), black_list::Set{Symbol} = Set{Symbol}()) Find chains of Op1s in the given Decapode. A chain ends when the target of the last Op1 is part of an Op2 or sum, or is a target -of multiple Op1s. Only operators with names included in the -allowable_ops set are allowed to be contracted. If the set is -empty then all operators are allowed. +of multiple Op1s. If a white list is provided, only chain those +operators. If a black list is provided, do not chain those operators. """ -function find_chains(d::SummationDecapode; allowable_ops::Set{Symbol} = Set{Symbol}()) +function find_chains(d::SummationDecapode; + white_list::Set{Symbol} = Set{Symbol}(), + black_list::Set{Symbol} = Set{Symbol}()) chains = [] visited = falses(nparts(d, :Op1)) # TODO: Re-write this without two reduce-vcats. - chain_starts = unique(reduce(vcat, reduce(vcat, - [incident(d, Vector{Int64}(filter(i -> !isnothing(i), infer_states(d))), :src), - incident(d, d[:res], :src), - incident(d, d[:sum], :src)]))) - - if(!isempty(allowable_ops)) - filter!(x -> d[x, :op1] ∈ allowable_ops, chain_starts) - end - + rvrv(x) = reduce(vcat, reduce(vcat, x)) + chain_starts = unique(rvrv( + [incident(d, Vector{Int64}(filter(i -> !isnothing(i), infer_states(d))), :src), + incident(d, d[:res], :src), + incident(d, d[:sum], :src), + incident(d, d[collect(Iterators.flatten(incident(d, collect(black_list), :op1))), :tgt], :src)])) + + passes_white_list(x) = isempty(white_list) ? true : x ∈ white_list + passes_black_list(x) = x ∉ black_list + + filter!(x -> passes_white_list(d[x, :op1]), chain_starts) + filter!(x -> passes_black_list(d[x, :op1]), chain_starts) + s = Stack{Int64}() foreach(x -> push!(s, x), chain_starts) while !isempty(s) @@ -368,11 +378,14 @@ function find_chains(d::SummationDecapode; allowable_ops::Set{Symbol} = Set{Symb is_tgt_of_many_ops(d, tgt) || !isempty(incident(d, tgt, :sum)) || !isempty(incident(d, tgt, :summand)) || - (!isempty(allowable_ops) && d[only(next_op1s), :op1] ∉ allowable_ops)) + !passes_white_list(d[only(next_op1s), :op1]) || + !passes_black_list(d[only(next_op1s), :op1])) # Terminate chain. append!(chains, [curr_chain]) - for next_op1 in next_op1s - visited[next_op1] || (!isempty(allowable_ops) && d[only(next_op1s), :op1] ∉ allowable_ops) || push!(s, next_op1) + for op1 in next_op1s + if !visited[op1] && passes_white_list(d[op1, :op1]) && passes_black_list(d[op1, :op1]) + push!(s, op1) + end end break end @@ -382,6 +395,7 @@ function find_chains(d::SummationDecapode; allowable_ops::Set{Symbol} = Set{Symb end return chains end + function add_constant!(d::AbstractNamedDecapode, k::Symbol) return add_part!(d, :Var, type=:Constant, name=k) end diff --git a/test/language.jl b/test/language.jl index a90b088..3e3ab4e 100644 --- a/test/language.jl +++ b/test/language.jl @@ -1283,6 +1283,21 @@ end op1 = [:⋆₂, :∂ₜ, :d₁] end + + t14_orig = @decapode begin + C == a(b(c(d(e(f(g(D))))))) + end + t14_contracted = contract_operators(t14_orig, + black_list=Set([:d])) + @test issetequal(t14_contracted[:op1], [:d, [:g, :f, :e], [:c, :b, :a]]) + + t15_orig = @decapode begin + C == a(b(c(d(e(f(g(D))))))) + end + t15_contracted = contract_operators(t15_orig, + white_list=Set([:a, :b, :c]), + black_list=Set([:d])) + @test issetequal(t15_contracted[:op1], [:g, :f, :e, :d, [:c, :b, :a]]) end @testset "ASCII & Vector Calculus Operators" begin