diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index f10bfec..9529578 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -14,7 +14,7 @@ Collage, collate, oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram, ## acset SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode, -contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, expand_operators, infer_state_names, recognize_types, +contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, expand_operators, infer_state_names, infer_terminal_names, recognize_types, resolve_overloads!, replace_names!, apply_inference_rule_op1!, apply_inference_rule_op2!, transfer_parents!, transfer_children!, diff --git a/src/acset.jl b/src/acset.jl index c1e924b..15742cc 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -202,7 +202,7 @@ end """ function infer_states(d::SummationDecapode) Find variables which have a time derivative or are not the source of a computation. -See also: [`infer_state_names`](@ref). +See also: [`infer_terminals`](@ref). """ function infer_states(d::SummationDecapode) parentless = filter(parts(d, :Var)) do v @@ -226,7 +226,7 @@ infer_state_names(d) = d[infer_states(d), :name] """ function infer_terminals(d::SummationDecapode) Find variables which have no children. -See also: [`infer_state_names`](@ref). +See also: [`infer_states`](@ref). """ function infer_terminals(d::SummationDecapode) filter(parts(d, :Var)) do v @@ -237,6 +237,13 @@ function infer_terminals(d::SummationDecapode) end end +""" function infer_terminal_names(d) + +Find names of variables which have no children. +See also: [`infer_terminals`](@ref). +""" +infer_terminal_names(d) = d[infer_terminals(d), :name] + """ function expand_operators(d::SummationDecapode) Find operations that are compositions, and expand them with intermediate variables. diff --git a/src/composition.jl b/src/composition.jl index cc6f802..4170379 100644 --- a/src/composition.jl +++ b/src/composition.jl @@ -220,6 +220,13 @@ oapply(r::RelationDiagram, pode::OpenSummationDecapode) = oapply(r, [pode]) # Default composition # ------------------- +# This helper function finds elements which appear in an array more than once. +function find_duplicates(vs::Vector{T}) where T + once, twice = Set{T}(), Set{T}() + foreach(v -> v ∈ once ? push!(twice,v) : push!(once,v), vs) + twice +end + # TODO: Upstream this to Catlab? function construct_relation_diagram(boxes::Vector{Symbol}, junctions::Vector{Vector{Symbol}}) tables = map(boxes, junctions) do b, j @@ -246,11 +253,11 @@ function default_composition_diagram(podes::Vector{D}, names::Vector{Symbol}, on pode[findall(!=(:Literal), pode[:type]), :name] end for (nln, name) in zip(non_lit_names, names) - allunique(nln) || error("Decapode $name contains repeated variable names: $nln.") + allunique(nln) || error("Decapode $name contains repeated variable names: $(find_duplicates(nln)).") end if only_states_terminals foreach(non_lit_names, podes) do nln, pode - outers = infer_state_names(pode) ∪ pode[infer_terminals(pode), :name] + outers = infer_state_names(pode) ∪ infer_terminal_names(pode) filter!(x -> x ∈ outers, nln) end end