Skip to content

Commit

Permalink
Update contract operators docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 authored May 1, 2024
1 parent b52aef3 commit 409a8fd
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,13 @@ function expand_operators(d::SummationDecapode)
return e
end

""" function contract_operators(d::SummationDecapode; allowable_ops::Set{Symbol} = Set{Symbol}())
""" function contract_operators(d::SummationDecapode; white_list::Set{Symbol} = Set{Symbol}(), black_list::Set{Symbol} = Set{Symbol}())
Find chains of Op1s in the given Decapode, and replace them with
a single Op1 with a vector of function names. After this process,
all Vars that are not a part of any computation are removed.
all Vars that are not a part of any computation are removed. If a
white list is provided, only chain those operators. If a black list
is provided, do not chain those operators.
"""
function contract_operators(d::SummationDecapode;
white_list::Set{Symbol} = Set{Symbol}(),
Expand Down Expand Up @@ -200,16 +202,12 @@ function remove_neighborless_vars!(d::SummationDecapode)
d
end

"""
function find_chains(d::SummationDecapode;
white_list::Set{Symbol} = Set{Symbol}(),
black_list::Set{Symbol} = Set{Symbol}())
""" function find_chains(d::SummationDecapode; white_list::Set{Symbol} = Set{Symbol}(), black_list::Set{Symbol} = Set{Symbol}())
Find chains of Op1s in the given Decapode. A chain ends when the
target of the last Op1 is part of an Op2 or sum, or is a target
of multiple Op1s. Only operators with names included in the
allowable_ops set are allowed to be contracted. If the set is
empty then all operators are allowed.
of multiple Op1s. If a white list is provided, only chain those
operators. If a black list is provided, do not chain those operators.
"""
function find_chains(d::SummationDecapode;
white_list::Set{Symbol} = Set{Symbol}(),
Expand All @@ -225,7 +223,6 @@ function find_chains(d::SummationDecapode;
d[collect(Iterators.flatten(incident(d, collect(black_list), :op1))), :tgt]
]))


passes_white_list(x) = isempty(white_list) ? true : x white_list
passes_black_list(x) = x black_list

Expand Down Expand Up @@ -268,6 +265,7 @@ function find_chains(d::SummationDecapode;
end
return chains
end

function add_constant!(d::AbstractNamedDecapode, k::Symbol)
return add_part!(d, :Var, type=:Constant, name=k)
end
Expand Down

0 comments on commit 409a8fd

Please sign in to comment.