diff --git a/Project.toml b/Project.toml index 4f64c78..6660fa2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,22 @@ name = "StructuredDecompositions" uuid = "ba32925c-6e4c-4640-bed9-b00febeea19a" authors = ["benjaminmerlinbumpus "] -version = "0.1.0" +version = "0.2.0" [deps] +AMD = "14f7f29c-3bd6-536c-9a0b-7339e30b5a3e" +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" +CuthillMcKee = "17f17636-5e38-52e3-a803-7ae3aaaf3da9" +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +Metis = "2679e427-3c69-5b7f-982b-ece356f1e94b" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] -Catlab = "^0.14" -MLStyle = "^0.4" -PartialFunctions = "^1.1" -julia = "^1.7" +Catlab = "0.16" +MLStyle = "0.4" +PartialFunctions = "1.1" +julia = "1.7" diff --git a/src/Decompositions.jl b/src/Decompositions.jl index 685ff8d..ef83687 100644 --- a/src/Decompositions.jl +++ b/src/Decompositions.jl @@ -139,8 +139,8 @@ adhesionSpans(d) = adhesionSpans(d, false) function elements_graph(el::Elements) F = FinFunctor(Dict(:V => :El, :E => :Arr), Dict(:src => :src, :tgt => :tgt), SchGraph, SchElements) - ΔF = DeltaMigration(F, Elements{Symbol}, Graph) - return ΔF(el) + ΔF = DeltaMigration(F) + return migrate(Graph, el, ΔF) end """Syntactic sugar for costrucitng the category of elements of a graph. @@ -152,8 +152,8 @@ function ∫(G::Elements) FinCat(elements_graph(G)) end #reverse direction of the edges function op_graph(g::Graph)::Graph F = FinFunctor(Dict(:V => :V, :E => :E), Dict(:src => :tgt, :tgt => :src), SchGraph, SchGraph) - ΔF = DeltaMigration(F, Graph, Graph) - return ΔF(g) + ΔF = DeltaMigration(F) + return migrate(Graph, g, ΔF) end """ diff --git a/src/StructuredDecompositions.jl b/src/StructuredDecompositions.jl index 58bee50..8c90cc6 100644 --- a/src/StructuredDecompositions.jl +++ b/src/StructuredDecompositions.jl @@ -4,6 +4,8 @@ module StructuredDecompositions include("Decompositions.jl") include("FunctorUtils.jl") include("DecidingSheaves.jl") +include("junction_trees/JunctionTrees.jl") +include("nested_uwds/NestedUWDs.jl") end diff --git a/src/junction_trees/JunctionTrees.jl b/src/junction_trees/JunctionTrees.jl new file mode 100644 index 0000000..87ae3d3 --- /dev/null +++ b/src/junction_trees/JunctionTrees.jl @@ -0,0 +1,83 @@ +module JunctionTrees + + +import AMD +import CuthillMcKee +import Metis + +using AbstractTrees +using Catlab.BasicGraphs +using DataStructures +using SparseArrays + +# Elimination Algorithms +export EliminationAlgorithm, AMDJL_AMD, CuthillMcKeeJL_RCM, MetisJL_ND, MCS + +# Supernodes +export Supernode, Node, MaximalSupernode, FundamentalSupernode + +# Orders +export Order + +# Elimination Trees +export EliminationTree +export getwidth, getsupernode, getsubtree, getlevel + +# Junction Trees +export JunctionTree +export getseperator, getresidual + + +# Add an element x to a sorted set v. +# Returns true if x ∉ v. +# Returns false if x ∈ v. +function insertsorted!(v::Vector, x::Integer) + i = searchsortedfirst(v, x) + + if i > length(v) || v[i] != x + insert!(v, i, x) + true + else + false + end +end + + +# Delete an element x from a sorted set v. +# Returns true if x ∈ v. +# Returns false if x ∉ v. +function deletesorted!(v::Vector, x::Integer) + i = searchsortedfirst(v, x) + + if i <= length(v) && v[i] == x + deleteat!(v, i) + true + else + false + end +end + + +# Delete the elements xs from a sorted set v. +# Returns true if xs and v intersect. +# Returns false if xs and v are disjoint. +function deletesorted!(v::Vector, xs::AbstractVector) + isintersecting = true + + for x in xs + isintersecting = deletesorted!(v, x) || isintersecting + end + + isintersecting +end + + +include("elimination_algorithms.jl") +include("supernodes.jl") +include("orders.jl") +include("trees.jl") +include("elimination_trees.jl") +include("junction_trees.jl") + + +end diff --git a/src/junction_trees/elimination_algorithms.jl b/src/junction_trees/elimination_algorithms.jl new file mode 100644 index 0000000..d302a74 --- /dev/null +++ b/src/junction_trees/elimination_algorithms.jl @@ -0,0 +1,45 @@ +""" + EliminationAlgorithm + +A graph elimination algorithm. The options are +- [`CuthillMcKeeJL_RCM`](@ref) +- [`AMDJL_AMD`](@ref) +- [`MetisJL_ND`](@ref) +- [`MCS`](@ref) +""" +abstract type EliminationAlgorithm end + + +""" + CuthillMcKeeJL_RCM <: EliminationAlgorithm + +The reverse Cuthill-McKee algorithm. Uses CuthillMckee.jl. +""" +struct CuthillMcKeeJL_RCM <: EliminationAlgorithm end + + +""" + AMDJL_AMD <: EliminationAlgorithm + +The approximate minimum degree algorithm. Uses AMD.jl. +""" +struct AMDJL_AMD <: EliminationAlgorithm end + + +""" + MetisJL_ND <: EliminationAlgorithm + +The nested dissection heuristic. Uses Metis.jl. +""" +struct MetisJL_ND <: EliminationAlgorithm end + + +""" + MCS <: EliminationAlgorithm + +The maximum cardinality search algorithm. +""" +struct MCS <: EliminationAlgorithm end + + +const DEFAULT_ELIMINATION_ALGORITHM = AMDJL_AMD() diff --git a/src/junction_trees/elimination_trees.jl b/src/junction_trees/elimination_trees.jl new file mode 100644 index 0000000..c5d57d9 --- /dev/null +++ b/src/junction_trees/elimination_trees.jl @@ -0,0 +1,319 @@ +# A supernodal elimination tree. +struct EliminationTree + order::Order + tree::Tree + firstsupernodelist::Vector{Int} + lastsupernodelist::Vector{Int} + subtreelist::Vector{Int} + width::Int +end + + +function EliminationTree( + order::Order, + tree::Tree, + supernodelist::AbstractVector, + subtreelist::AbstractVector, + width::Integer) + + n = length(order) + m = length(tree) + postorder = Order(n) + firstsupernodelist = Vector{Int}(undef, m) + lastsupernodelist = Vector{Int}(undef, m) + + i₂ = 0 + + for j in 1:m + supernode = supernodelist[j] + i₁ = i₂ + 1 + i₂ = i₂ + length(supernode) + firstsupernodelist[j] = i₁ + lastsupernodelist[j] = i₂ + postorder[i₁:i₂] .= supernode + end + + order = compose(postorder, order) + subtreelist = subtreelist[postorder] + + EliminationTree( + order, + tree, + firstsupernodelist, + lastsupernodelist, + subtreelist, + width) +end + + +# Construct a supernodal elimination tree. +# +# The complexity is +# 𝒪(m α(m, n) + n) +# where m = |E|, n = |V|, and α is the inverse Ackermann function. +function EliminationTree( + graph::AbstractSymmetricGraph, + order::Order, + supernode::Supernode=DEFAULT_SUPERNODE) + + etree = Tree(graph, order) + _, outdegreelist = getdegrees(graph, order, etree) + + supernodelist, subtreelist, parentlist = makestree( + etree, + outdegreelist, + supernode) + + n = nv(graph) + tree = Tree(subtreelist[n], parentlist) + postorder, tree = makepostorder(tree) + + supernodelist = supernodelist[postorder] + subtreelist = postorder.index[subtreelist] + width = maximum(outdegreelist) + + EliminationTree( + order, + tree, + supernodelist, + subtreelist, + width) +end + + +# Construct a supernodal elimination tree, first computing an elimination order. +function EliminationTree( + graph::AbstractSymmetricGraph, + algorithm::EliminationAlgorithm=DEFAULT_ELIMINATION_ALGORITHM, + supernode::Supernode=DEFAULT_SUPERNODE) + + order = Order(graph, algorithm) + EliminationTree(graph, order, supernode) +end + + +# Get the number of nodes in a supernodal elimination tree. +function Base.length(stree::EliminationTree) + length(stree.tree) +end + + +# Get the width of a supernodal elimination tree. +function getwidth(stree::EliminationTree) + stree.width +end + + +# Get the supernode at node i. +function getsupernode(stree::EliminationTree, i::Integer) + i₁ = stree.firstsupernodelist[i] + i₂ = stree.lastsupernodelist[i] + stree.order[i₁:i₂] +end + + +# Get the highest node containing a vertex v. +function getsubtree(stree::EliminationTree, v::Integer) + stree.subtreelist[stree.order.index[v]] +end + + +# Get the highest node containing vertices vs. +function getsubtree(stree::EliminationTree, vs::AbstractVector) + init = length(stree.order) + stree.subtreelist[minimum(stree.order.index[vs]; init)] +end + + +# Get the level of node i. +function getlevel(stree::EliminationTree, i::Integer) + getlevel(stree.tree, i) +end + + +# Evaluate whether node i₁ is a descendant of node i₂. +function AbstractTrees.isdescendant(stree::EliminationTree, i₁::Integer, i₂::Integer) + isdescendant(stree.tree, i₁, i₂) +end + + +# Compute the supernodes, parent function, and first ancestor of a +# supernodal elimination tree. +# +# The complexity is +# 𝒪(n) +# where n = |V|. +# +# doi:10.1561/2400000006 +# Algorithm 4.1: Maximal supernodes and supernodal elimination tree. +function makestree(etree::Tree, outdegrees::AbstractVector, supernode::Supernode) + n = length(etree) + sbt = Vector{Int}(undef, n) + snd = Vector{Int}[] + q = Int[] + a = Int[] + + for v in 1:n + w′ = findchild(etree, outdegrees, v, supernode) + + if isnothing(w′) + i = length(snd) + 1 + sbt[v] = i + push!(snd, [v]) + push!(q, 0) + push!(a, 0) + else + i = sbt[w′] + sbt[v] = i + push!(snd[i], v) + end + + for w in childindices(etree, v) + if w !== w′ + j = sbt[w] + q[j] = i + a[j] = v + end + end + end + + snd, sbt, q, a +end + + +# Find a child w of v such that +# v ∈ snd(w). +# If no such child exists, return nothing. +function findchild(etree::Tree, outdegrees::AbstractVector, v::Integer, ::Supernode) end + + +function findchild(etree::Tree, outdegrees::AbstractVector, v::Integer, ::MaximalSupernode) + for w in childindices(etree, v) + if outdegrees[w] == outdegrees[v] + 1 + return w + end + end +end + + +function findchild(etree::Tree, outdegrees::AbstractVector, v::Integer, ::FundamentalSupernode) + ws = childindices(etree, v) + + if length(ws) == 1 + w = only(ws) + + if outdegrees[w] == outdegrees[v] + 1 + return w + end + end +end + + +# Compute the row and column counts of a graph's elimination graph. +# +# The complexity is +# 𝒪(m α(m, n)) +# where m = |E|, n = |V|, and α is the inverse Ackermann function. +# +# doi:10.1137/S089547989223692 +# Figure 3: Implementation of algorithm to compute row and column counts +function getdegrees(graph::AbstractSymmetricGraph, order::Order, etree::Tree) + n = nv(graph) + forest = IntDisjointSets(n) + rvert = Vector{Int}(undef, n) + index = Vector{Int}(undef, n) + rvert .= index .= 1:n + + function FIND(p) + index[find_root!(forest, p)] + end + + function UNION(u, v) + w = max(u, v) + rvert[w] = root_union!(forest, rvert[u], rvert[v]) + index[rvert[w]] = w + end + + postorder, etree = makepostorder(etree) + graph = Graph(graph, compose(postorder, order)) + prev_p = Vector{Int}(undef, n) + prev_nbr = Vector{Int}(undef, n) + rc = Vector{Int}(undef, n) + wt = Vector{Int}(undef, n) + + for u in 1:n + prev_p[u] = 0 + prev_nbr[u] = 0 + rc[u] = 1 + wt[u] = isempty(childindices(etree, u)) + end + + for p in 1:n + if p != n + wt[parentindex(etree, p)] -= 1 + end + + for u in neighbors(graph, p) + if getfirstdescendant(etree, p) > prev_nbr[u] + wt[p] += 1 + p′ = prev_p[u] + + if p′ == 0 + rc[u] += getlevel(etree, p) - getlevel(etree, u) + else + q = FIND(p′) + rc[u] += getlevel(etree, p) - getlevel(etree, q) + wt[q] -= 1 + end + + prev_p[u] = p + end + + prev_nbr[u] = p + end + + if p != n + UNION(p, parentindex(etree, p)) + end + end + + cc = wt + + for v in 1:n - 1 + cc[parentindex(etree, v)] += cc[v] + end + + indegrees = rc[postorder.index] .- 1 + outdegrees = cc[postorder.index] .- 1 + indegrees, outdegrees +end + + +########################## +# Indexed Tree Interface # +########################## + + +function AbstractTrees.rootindex(stree::EliminationTree) + rootindex(stree.tree) +end + + +function AbstractTrees.parentindex(stree::EliminationTree, i::Integer) + parentindex(stree.tree, i) +end + + +function AbstractTrees.childindices(stree::EliminationTree, i::Integer) + childindices(stree.tree, i) +end + + +function AbstractTrees.NodeType(::Type{IndexNode{EliminationTree, Int}}) + HasNodeType() +end + + +function AbstractTrees.nodetype(::Type{IndexNode{EliminationTree, Int}}) + IndexNode{EliminationTree, Int} +end diff --git a/src/junction_trees/junction_trees.jl b/src/junction_trees/junction_trees.jl new file mode 100644 index 0000000..87e22e0 --- /dev/null +++ b/src/junction_trees/junction_trees.jl @@ -0,0 +1,189 @@ +# A junction tree. +struct JunctionTree + stree::EliminationTree + seperatorlist::Vector{Vector{Int}} +end + + +# Construct a tree decomposition. +function JunctionTree(graph::AbstractSymmetricGraph, stree::EliminationTree) + graph = makeeliminationgraph(graph, stree) + + n = length(stree) + seperatorlist = Vector{Vector{Int}}(undef, n) + seperatorlist[n] = [] + + for i in 1:n - 1 + v₁ = stree.firstsupernodelist[i] + v₂ = stree.lastsupernodelist[i] + bag = collect(neighbors(graph, v₁)) + sort!(bag) + seperatorlist[i] = bag[v₂ - v₁ + 1:end] + end + + JunctionTree(stree, seperatorlist) +end + + +# Reorient a juncton tree towards the given root. +function JunctionTree(root::Integer, jtree::JunctionTree) + m = length(jtree.stree.order) + n = length(jtree) + seperatorlist = Vector{Vector{Int}}(undef, n) + supernodelist = Vector{Vector{Int}}(undef, n) + subtreelist = Vector{Int}(undef, m) + + v₁ = jtree.stree.firstsupernodelist[root] + v₂ = jtree.stree.lastsupernodelist[root] + seperatorlist[n] = [] + supernodelist[n] = [v₁:v₂; jtree.seperatorlist[root]] + subtreelist[supernodelist[n]] .= n + + tree = Tree(root, jtree.stree.tree) + postorder, tree = makepostorder(tree) + + for i in 1:n - 1 + j = postorder[i] + v₁ = jtree.stree.firstsupernodelist[j] + v₂ = jtree.stree.lastsupernodelist[j] + + if isdescendant(jtree, root, j) + seperatorlist[i] = jtree.seperatorlist[postorder[parentindex(tree, i)]] + supernodelist[i] = [v₁:v₂; jtree.seperatorlist[j]] + deletesorted!(supernodelist[i], seperatorlist[i]) + else + seperatorlist[i] = jtree.seperatorlist[j] + supernodelist[i] = v₁:v₂ + end + + subtreelist[supernodelist[i]] .= i + end + + order = jtree.stree.order + width = jtree.stree.width + stree = EliminationTree(order, tree, supernodelist, subtreelist, width) + + for i in 1:n + seperatorlist[i] = stree.order.index[order[seperatorlist[i]]] + sort!(seperatorlist[i]) + end + + JunctionTree(stree, seperatorlist) +end + + +# Construct a tree decomposition, first computing an elimination order and a supernodal +# elimination tree. +function JunctionTree( + graph::AbstractSymmetricGraph, + algorithm::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM, + supernode::Supernode=DEFAULT_SUPERNODE) + + stree = EliminationTree(graph, algorithm, supernode) + JunctionTree(graph, stree) +end + + +# Get the number of nodes in a junction tree. +function Base.length(jtree::JunctionTree) + length(jtree.stree) +end + + +# Get the width of a junction tree. +function getwidth(jtree::JunctionTree) + getwidth(jtree.stree) +end + + +# Get the seperator at node i. +function getseperator(jtree::JunctionTree, i::Integer) + jtree.stree.order[jtree.seperatorlist[i]] +end + + +# Get the residual at node i. +function getresidual(jtree::JunctionTree, i::Integer) + getsupernode(jtree.stree, i) +end + + +# Get the highest node containing the vertex v. +function getsubtree(jtree::JunctionTree, v::Union{Integer, AbstractVector}) + getsubtree(jtree.stree, v) +end + + +# Get the level of node i. +function getlevel(jtree::JunctionTree, i::Integer) + getlevel(jtree.stree, i) +end + + +# Evaluate whether node i₁ is a descendant of node i₂. +function AbstractTrees.isdescendant(jtree::JunctionTree, i₁::Integer, i₂::Integer) + isdescendant(jtree.stree, i₁, i₂) +end + + +# Construct an elimination graph. +function makeeliminationgraph(graph::AbstractSymmetricGraph, stree::EliminationTree) + n = length(stree) + graph = Graph(graph, stree.order) + + for i in 1:n - 1 + u₁ = stree.firstsupernodelist[i] + u₂ = stree.lastsupernodelist[i] + + for u in u₁:u₂ - 1 + v = u + 1 + + for w in neighbors(graph, u) + if v != w && !has_edge(graph, v, w) + add_edge!(graph, v, w) + end + end + end + + u = u₂ + v = stree.firstsupernodelist[parentindex(stree, i)] + + for w in neighbors(graph, u) + if v != w && !has_edge(graph, v, w) + add_edge!(graph, v, w) + end + end + end + + graph +end + + +########################## +# Indexed Tree Interface # +########################## + + +function AbstractTrees.rootindex(jtree::JunctionTree) + rootindex(jtree.stree) +end + + +function AbstractTrees.parentindex(jtree::JunctionTree, i::Integer) + parentindex(jtree.stree, i) +end + + +function AbstractTrees.childindices(jtree::JunctionTree, i::Integer) + childindices(jtree.stree, i) +end + + +function AbstractTrees.NodeType(::Type{IndexNode{JunctionTree, Int}}) + HasNodeType() +end + + +function AbstractTrees.nodetype(::Type{IndexNode{JunctionTree, Int}}) + IndexNode{JunctionTree, Int} +end diff --git a/src/junction_trees/orders.jl b/src/junction_trees/orders.jl new file mode 100644 index 0000000..48f2a67 --- /dev/null +++ b/src/junction_trees/orders.jl @@ -0,0 +1,170 @@ +# A total ordering of the numbers {1, ..., n}. +struct Order <: AbstractVector{Int} + order::Vector{Int} + index::Vector{Int} +end + + +# Given a vector σ, construct the order ≺, where +# σ(i₁) ≺ σ(i₂) +# if +# i₁ < i₂. +function Order(order::AbstractVector) + n = length(order) + index = Vector{Int}(undef, n) + + for i in 1:n + index[order[i]] = i + end + + Order(order, index) +end + + +# Construct an empty order of length n. +function Order(n::Integer) + order = Vector{Int}(undef, n) + index = Vector{Int}(undef, n) + Order(order, index) +end + + +# Construct an elimination order using the reverse Cuthill-McKee algorithm. Uses +# CuthillMcKee.jl. +function Order(graph::AbstractSymmetricGraph, ::CuthillMcKeeJL_RCM) + order = CuthillMcKee.symrcm(adjacencymatrix(graph)) + Order(order) +end + + +# Construct an elimination order using the approximate minimum degree algorithm. Uses +# AMD.jl. +function Order(graph::AbstractSymmetricGraph, ::AMDJL_AMD) + order = AMD.symamd(adjacencymatrix(graph)) + Order(order) +end + + +# Construct an elimination order using the nested dissection heuristic. Uses Metis.jl. +function Order(graph::AbstractSymmetricGraph, ::MetisJL_ND) + order, index = Metis.permutation(adjacencymatrix(graph)) + Order(order, index) +end + + +# Construct an elimination order using the maximum cardinality search algorithm. +function Order(graph::AbstractSymmetricGraph, ::MCS) + order, index = mcs(graph) + Order(order, index) +end + + +# Compose as permutations. +function compose(order₁::Order, order₂::Order) + order = order₂.order[order₁.order] + index = order₁.index[order₂.index] + Order(order, index) +end + + +# Evaluate whether +# n₁ < n₂ +# in the given order. +function Base.isless(order::Order, n₁::Integer, n₂::Integer) + order.index[n₁] < order.index[n₂] +end + + +# Compute a vertex elimination order using the maximum cardinality search algorithm. +# +# The complexity is +# 𝒪(m + n), +# where m = |E| and n = |V|. +# +# https://doi.org/10.1137/0213035 +# Maximum cardinality search +function mcs(graph::AbstractSymmetricGraph) + n = nv(graph) + α = Vector{Int}(undef, n) + α⁻¹ = Vector{Int}(undef, n) + size = Vector{Int}(undef, n) + set = Vector{Vector{Int}}(undef, n) + + set .= [[]] + size .= 1 + append!(set[1], vertices(graph)) + + i = n + j = 1 + + while i >= 1 + v = pop!(set[j]) + α[v] = i + α⁻¹[i] = v + size[v] = 0 + + for w in neighbors(graph, v) + if size[w] >= 1 + deletesorted!(set[size[w]], w) + size[w] += 1 + insertsorted!(set[size[w]], w) + end + end + + i -= 1 + j += 1 + + while j >= 1 && isempty(set[j]) + j -= 1 + end + end + + α⁻¹, α +end + + +# Construct the adjacency matrix of a graph. +function adjacencymatrix(graph::AbstractSymmetricGraph) + m = ne(graph) + n = nv(graph) + + colptr = ones(Int, n + 1) + rowval = sizehint!(Vector{Int}(), 2m) + + for j in 1:n + ns = collect(neighbors(graph, j)) + sort!(ns) + colptr[j + 1] = colptr[j] + length(ns) + append!(rowval, ns) + end + + nzval = ones(Int, length(rowval)) + SparseMatrixCSC(n, n, colptr, rowval, nzval) +end + + +############################ +# AbstractVector Interface # +############################ + + +function Base.size(order::Order) + (length(order.order),) +end + + +function Base.getindex(order::Order, i::Integer) + order.order[i] +end + + +function Base.setindex!(order::Order, v::Integer, i::Integer) + order.order[i] = v + order.index[v] = i + v +end + + +function Base.IndexStyle(::Type{Order}) + IndexLinear() +end diff --git a/src/junction_trees/supernodes.jl b/src/junction_trees/supernodes.jl new file mode 100644 index 0000000..af940bb --- /dev/null +++ b/src/junction_trees/supernodes.jl @@ -0,0 +1,36 @@ +""" + Supernode + +A type of supernode. The options are +- [`Node`](@ref) +- [`MaximalSupernode`](@ref) +- [`FundamentalSupernode`](@ref) +""" +abstract type Supernode end + + +""" + Node <: Supernode + +A single-vertex supernode. +""" +struct Node <: Supernode end + + +""" + MaximalSupernode <: Supernode + +A maximal supernode. +""" +struct MaximalSupernode <: Supernode end + + +""" + FundamentalSupernode <: Supernode + +A fundamental supernode. +""" +struct FundamentalSupernode <: Supernode end + + +const DEFAULT_SUPERNODE = MaximalSupernode() diff --git a/src/junction_trees/trees.jl b/src/junction_trees/trees.jl new file mode 100644 index 0000000..ab8655f --- /dev/null +++ b/src/junction_trees/trees.jl @@ -0,0 +1,232 @@ +# A rooted tree. +struct Tree + root::Int + parentlist::Vector{Int} + childrenlist::Vector{Vector{Int}} + levellist::Vector{Int} + firstdescendantlist::Vector{Int} +end + + +# Orient a tree towards the given root. +function Tree(root::Integer, tree::Tree) + i = root + parent = parentindex(tree, i) + parentlist = copy(tree.parentlist) + childrenlist = deepcopy(tree.childrenlist) + + while !isnothing(parent) + parentlist[parent] = i + push!(childrenlist[i], parent) + deletesorted!(childrenlist[parent], i) + i = parent + parent = parentindex(tree, i) + end + + Tree(root, parentlist, childrenlist) +end + + +# Construct a tree from a list of parent and a list of children. +function Tree(root::Integer, parentlist::AbstractVector, childrenlist::AbstractVector) + n = length(parentlist) + levellist = Vector{Int}(undef, n) + firstdescendantlist = Vector{Int}(undef, n) + Tree(root, parentlist, childrenlist, levellist, firstdescendantlist) +end + + +# Construct a tree from a list of parents. +function Tree(root::Integer, parentlist::AbstractVector) + n = length(parentlist) + childrenlist = Vector{Vector{Int}}(undef, n) + childrenlist .= [[]] + + for i in 1:n + if i != root + push!(childrenlist[parentlist[i]], i) + end + end + + Tree(root, parentlist, childrenlist) +end + + +# Construct an elimination tree. +function Tree(graph::AbstractSymmetricGraph, order::Order) + n = nv(graph) + parentlist = makeetree(graph, order) + @assert count(parentlist .== 0) == 1 + Tree(n, parentlist) +end + + +function Base.length(tree::Tree) + length(tree.parentlist) +end + + +# Compute the parent vector of the elimination tree of the elimination graph of a ordered +# graph. +# +# The complexity is +# 𝒪(m log(n)) +# where m = |E| and n = |V|. +# +# doi:10.1145/6497.6499 +# Algorithm 4.2: Elimination Tree by Path Compression +function makeetree(graph::AbstractSymmetricGraph, order::Order) + graph = Graph(graph, order) + + n = nv(graph) + parent = Vector{Int}(undef, n) + ancestor = Vector{Int}(undef, n) + + for i in 1:n + parent[i] = 0 + ancestor[i] = 0 + + for k in inneighbors(graph, i) + r = k + + while ancestor[r] != 0 && ancestor[r] != i + t = ancestor[r] + ancestor[r] = i + r = t + end + + if ancestor[r] == 0 + ancestor[r] = i + parent[r] = i + end + end + end + + parent +end + + +# Given an ordered graph +# (G, σ), +# construct a directed graph by ordering the edges in G from lower to higher index. +# +# The complexity is +# 𝒪(m) +# where m = |E|. +function BasicGraphs.Graph(graph::AbstractSymmetricGraph, order::Order) + n = nv(graph) + digraph = Graph(n) + + for v in vertices(graph) + i = order.index[v] + + for w in neighbors(graph, v) + j = order.index[w] + + if i < j + add_edge!(digraph, i, j) + end + end + end + + digraph +end + + +############## +# Postorders # +############## + + +# Get the level of node i. +# This function only works on postordered trees. +function getlevel(tree::Tree, i::Integer) + tree.levellist[i] +end + + +# Get the first descendant of node i. +# This function only works on postordered trees. +function getfirstdescendant(tree::Tree, i::Integer) + tree.firstdescendantlist[i] +end + + +# Evaluate whether node i₁ is a descendant of node i₂. +# This function only works on postordered trees. +function AbstractTrees.isdescendant(tree::Tree, i₁::Integer, i₂::Integer) + getfirstdescendant(tree, i₂) <= i₁ < i₂ +end + + +# Compute a postordering of a tree. +# +# The complexity is +# 𝒪(n) +# where n = |V|. +function makepostorder(tree::Tree) + n = length(tree) + order = Order(n) + parentlist = Vector{Int}(undef, n) + childrenlist = Vector{Vector{Int}}(undef, n) + levellist = Vector{Int}(undef, n) + firstdescendantlist = Vector{Int}(undef, n) + + root, nodes... = PreOrderDFS(IndexNode(tree)) + + order[n] = root.index + parentlist[n] = 0 + childrenlist[n] = [] + levellist[n] = 0 + + for (i, node) in enumerate(nodes) + j = n - i + order[j] = node.index + + k = order.index[parentindex(tree, node.index)] + parentlist[j] = k + childrenlist[j] = [] + pushfirst!(childrenlist[k], j) + levellist[j] = 1 + levellist[k] + end + + for i in 1:n + init = i + firstdescendantlist[i] = minimum(firstdescendantlist[childrenlist[i]]; init) + end + + tree = Tree(n, parentlist, childrenlist, levellist, firstdescendantlist) + order, tree +end + + +########################## +# Indexed Tree Interface # +########################## + + +function AbstractTrees.rootindex(tree::Tree) + tree.root +end + + +function AbstractTrees.parentindex(tree::Tree, i::Integer) + if i != rootindex(tree) + tree.parentlist[i] + end +end + + +function AbstractTrees.childindices(tree::Tree, i::Integer) + tree.childrenlist[i] +end + + +function AbstractTrees.NodeType(::Type{IndexNode{Tree, Int}}) + HasNodeType() +end + + +function AbstractTrees.nodetype(::Type{IndexNode{Tree, Int}}) + IndexNode{Tree, Int} +end diff --git a/src/nested_uwds/NestedUWDs.jl b/src/nested_uwds/NestedUWDs.jl new file mode 100644 index 0000000..5de7eca --- /dev/null +++ b/src/nested_uwds/NestedUWDs.jl @@ -0,0 +1,34 @@ +module NestedUWDs + + +using AbstractTrees +using Catlab.ACSetInterface +using Catlab.BasicGraphs +using Catlab.DirectedWiringDiagrams +using Catlab.DirectedWiringDiagrams: WiringDiagramACSet +using Catlab.MonoidalUndirectedWiringDiagrams +using Catlab.MonoidalUndirectedWiringDiagrams: UntypedHypergraphDiagram +using Catlab.RelationalPrograms +using Catlab.RelationalPrograms: TypedUnnamedRelationDiagram +using Catlab.Theories +using Catlab.UndirectedWiringDiagrams +using Catlab.WiringDiagramAlgebras + +using ..JunctionTrees +using ..JunctionTrees: insertsorted!, DEFAULT_ELIMINATION_ALGORITHM, DEFAULT_SUPERNODE + +# Elimination Algorithms +export EliminationAlgorithm, AMDJL_AMD, CuthillMcKeeJL_RCM, MetisJL_ND, MCS + +# Supernodes +export Supernode, Node, MaximalSupernode, FundamentalSupernode + +# Nested UWDs +export NestedUWD +export evalschedule, makeschedule, makeoperations + + +include("nested_uwds.jl") + + +end diff --git a/src/nested_uwds/nested_uwds.jl b/src/nested_uwds/nested_uwds.jl new file mode 100644 index 0000000..54dd648 --- /dev/null +++ b/src/nested_uwds/nested_uwds.jl @@ -0,0 +1,312 @@ +""" + NestedUWD{T, B, V} + +An undirected wiring diagram, represented as a nested collected of undirected wiring +diagrams. +""" +struct NestedUWD{T, B, V} + diagram::TypedUnnamedRelationDiagram{T, B, V} + jtree::JunctionTree + assignmentlist::Vector{Int} + assignmentindex::Vector{Vector{Int}} +end + + +function NestedUWD( + diagram::D, + jtree::JunctionTree, + assignmentlist::AbstractVector, + assignmentindex::AbstractVector) where D <: UndirectedWiringDiagram + + T, B, V = getattributetypes(D) + relation = TypedUnnamedRelationDiagram{T, B, V}() + copy_parts!(relation, diagram) + NestedUWD{T, B, V}(relation, jtree, assignmentlist, assignmentindex) +end + + +function NestedUWD(diagram::UndirectedWiringDiagram, jtree::JunctionTree) + n = nparts(diagram, :Box) + m = length(jtree) + assignmentlist = Vector{Int}(undef, n) + assignmentindex = Vector{Vector{Int}}(undef, m) + assignmentindex .= [[]] + + for b in 1:n + i = getsubtree(jtree, diagram[incident(diagram, b, :box), :junction]) + assignmentlist[b] = i + push!(assignmentindex[i], b) + end + + NestedUWD(diagram, jtree, assignmentlist, assignmentindex) +end + + +""" + NestedUWD( + diagram::UndirectedWiringDiagram, + [, algorithm::Union{Order, EliminationAlgorithm}] + [, supernode::Supernode]) + +Construct a nested undirected wiring diagram. +""" +function NestedUWD( + diagram::UndirectedWiringDiagram, + algorithm::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM, + supernode::Supernode=DEFAULT_SUPERNODE) + + jtree = JunctionTree(diagram, algorithm, supernode) + NestedUWD(diagram, jtree) +end + + +# Construct a tree decomposition of the line graph of an undirected wiring diagram. +function JunctionTree( + diagram::UndirectedWiringDiagram, + algorithm::Union{Order, EliminationAlgorithm}=DEFAULT_ELIMINATION_ALGORITHM, + supernode::Supernode=DEFAULT_SUPERNODE) + + graph = makegraph(diagram) + jtree = JunctionTree(graph, algorithm, supernode) + + query = diagram[:outer_junction] + JunctionTree(getsubtree(jtree, query), jtree) +end + + +# Construct the line graph of an undirected wiring diagram. +function makegraph(diagram::UndirectedWiringDiagram) + n = nparts(diagram, :Junction) + m = nparts(diagram, :Box) + graph = SymmetricGraph(n) + + for i in 1:m + junctions = diagram[incident(diagram, i, :box), :junction] + l = length(junctions) + + for i₁ in 1:l - 1 + j₁ = junctions[i₁] + + for i₂ in i₁ + 1:l + j₂ = junctions[i₂] + + if !has_edge(graph, j₁, j₂) + add_edge!(graph, j₁, j₂) + end + end + end + end + + junctions = diagram[:, :outer_junction] + l = length(junctions) + + for i₁ in 1:l - 1 + j₁ = junctions[i₁] + + for i₂ in i₁ + 1:l + j₂ = junctions[i₂] + + if !has_edge(graph, j₁, j₂) + add_edge!(graph, j₁, j₂) + end + end + end + + graph +end + + +""" + makeschedule(nuwd::NestedUWD) + +Construct a directed wiring diagram that represents the nesting structure of a nested UWD. +""" +function makeschedule(nuwd::NestedUWD{<:Any, T}) where T + m = length(nuwd.assignmentlist) + n = length(nuwd.jtree) + + parents = map(1:n - 1) do i + parentindex(nuwd.jtree, i) + end + + costs = map(1:n) do i + length(getresidual(nuwd.jtree, i)) + length(getseperator(nuwd.jtree, i)) + end + + schedule = WiringDiagramACSet{T, Nothing, Union{Int, AbstractBox}, DataType}() + + add_parts!(schedule, :Box, n) + add_parts!(schedule, :Wire, n - 1) + add_parts!(schedule, :InPort, m + n - 1) + add_parts!(schedule, :InWire, m) + add_parts!(schedule, :OutPort, n) + add_parts!(schedule, :OutWire, 1) + add_parts!(schedule, :OuterInPort, m) + add_parts!(schedule, :OuterOutPort, 1) + + schedule[:, :src] = 1:n - 1 + schedule[:, :tgt] = m + 1:m + n - 1 + schedule[:, :in_src] = 1:m + schedule[:, :in_tgt] = 1:m + schedule[:, :out_src] = n:n + schedule[:, :out_tgt] = 1:1 + schedule[:, :in_port_box] = [nuwd.assignmentlist; parents] + schedule[:, :out_port_box] = 1:n + + schedule[:, :value] = costs + schedule[:, :box_type] = Box{Int} + schedule[:, :outer_in_port_type] = nuwd.diagram[:, :name] + + Theory = ThSymmetricMonoidalCategory.Meta.T + WiringDiagram{Theory, T, Nothing, Int}(schedule, nothing) +end + + +""" + function evalschedule( + f, + nuwd::NestedUWD, + generators::Union{AbstractVector, AbstractDict} + [, operations::AbstractVector]) + +Evaluate an undirected wiring diagrams given a set of generators for the boxes. The +optional first argument `f` should be callable with the signature +``` + f(diagram, generators) +``` +where `diagram` is an undirected wiring diagram, and `generators` is a vector. If `f` is not +specified, then it defaults to `oapply`. +""" +function evalschedule( + f, + nuwd::NestedUWD, + generators::AbstractVector{T}, + operations::AbstractVector=makeoperations(nuwd)) where T + + n = length(nuwd.jtree) + mailboxes = Vector{T}(undef, n) + + for i in 1:n + g₁ = generators[nuwd.assignmentindex[i]] + g₂ = mailboxes[childindices(nuwd.jtree, i)] + mailboxes[i] = f(operations[i], [g₁; g₂]) + end + + mailboxes[n] +end + + +function evalschedule( + f, + nuwd::NestedUWD, + generators::AbstractDict{<:Any, T}, + operations::AbstractVector=makeoperations(nuwd)) where T + + g = generators + n = nparts(nuwd.diagram, :Box) + generators = Vector{T}(undef, n) + + for i in 1:n + generators[i] = g[nuwd.diagram[i, :name]] + end + + evalschedule(f, nuwd, generators, operations) +end + + +function evalschedule( + nuwd::NestedUWD, + generators::Union{AbstractVector, AbstractDict}, + operations::AbstractVector=makeoperations(nuwd)) + + evalschedule(oapply, nuwd, generators, operations) +end + + +# For each node i of a nested UWD, construct the undirected wiring diagram corresponding to i. +function makeoperations(nuwd::NestedUWD) + m = length(nuwd.jtree) + + map(1:m) do i + makeoperation(nuwd, i) + end +end + + +# Construct the undirected wiring diagram corresponding to node i of a nested UWD. +function makeoperation(nuwd::NestedUWD{T, B, V}, i::Integer) where {T, B, V} + function findjunction(j::Integer) + v = nuwd.jtree.stree.order.index[j] + v₁ = nuwd.jtree.stree.firstsupernodelist[i] + v₂ = nuwd.jtree.stree.lastsupernodelist[i] + + if v <= v₂ + v - v₁ + 1 + else + v₂ - v₁ + 1 + searchsortedfirst(nuwd.jtree.seperatorlist[i], v) + end + end + + residual = getresidual(nuwd.jtree, i) + seperator = getseperator(nuwd.jtree, i) + m = length(residual) + n = length(seperator) + + operation = TypedUnnamedRelationDiagram{T, B, V}() + add_parts!(operation, :Junction, m + n) + + operation[1:m, :junction_type] = nuwd.diagram[residual, :junction_type] + operation[1:m, :variable] = nuwd.diagram[residual, :variable] + operation[m + 1:m + n, :junction_type] = nuwd.diagram[seperator, :junction_type] + operation[m + 1:m + n, :variable] = nuwd.diagram[seperator, :variable] + + if i < length(nuwd.jtree) + for j in seperator + p′ = add_part!(operation, :OuterPort) + operation[p′, :outer_junction] = m + p′ + operation[p′, :outer_port_type] = nuwd.diagram[j, :junction_type] + end + else + for j in nuwd.diagram[:outer_junction] + p′ = add_part!(operation, :OuterPort) + operation[p′, :outer_junction] = findjunction(j) + operation[p′, :outer_port_type] = nuwd.diagram[j, :junction_type] + end + end + + for b in nuwd.assignmentindex[i] + b′ = add_part!(operation, :Box) + operation[b′, :name] = nuwd.diagram[b, :name] + + for j in nuwd.diagram[incident(nuwd.diagram, b, :box), :junction] + p′ = add_part!(operation, :Port) + operation[p′, :box] = b′ + operation[p′, :junction] = findjunction(j) + operation[p′, :port_type] = nuwd.diagram[j, :junction_type] + end + end + + for b in childindices(nuwd.jtree, i) + b′ = add_part!(operation, :Box) + + for j in getseperator(nuwd.jtree, b) + p′ = add_part!(operation, :Port) + operation[p′, :box] = b′ + operation[p′, :junction] = findjunction(j) + operation[p′, :port_type] = nuwd.diagram[j, :junction_type] + end + end + + operation +end + + +# Get the attribute types of an undirected wiring diagram. +function getattributetypes(::Type{<:UntypedRelationDiagram{B, V}}) where {B, V} + Nothing, B, V +end + + +function getattributetypes(::Type{<:TypedRelationDiagram{T, B, V}}) where {T, B, V} + T, B, V +end diff --git a/test/JunctionTrees.jl b/test/JunctionTrees.jl new file mode 100644 index 0000000..9001215 --- /dev/null +++ b/test/JunctionTrees.jl @@ -0,0 +1,285 @@ +using AbstractTrees +using Catlab.BasicGraphs +using Catlab.RelationalPrograms +using Catlab.UndirectedWiringDiagrams +using StructuredDecompositions.JunctionTrees +using Test + + +# Vandenberghe and Andersen +# Chordal Graphs and Semidefinite Optimization +graph = SymmetricGraph(17) + +add_edges!(graph, + [1, 1, 1, 1, 2, 2, 5, 5, 6, 6, 7, 7, 7, 10, 10, 10, 10, 12, 12, 12, 12, 15], + [3, 4, 5, 15, 3, 4, 9, 16, 9, 16, 8, 9, 15, 11, 13, 14, 17, 13, 14, 16, 17, 17]) + +order = JunctionTrees.Order(graph, CuthillMcKeeJL_RCM()) +@test order == [2, 14, 13, 11, 4, 3, 12, 10, 16, 1, 17, 5, 6, 15, 9, 7, 8] + +order = JunctionTrees.Order(graph, AMDJL_AMD()) +@test order == [8, 11, 7, 2, 4, 3, 1, 6, 13, 14, 10, 12, 17, 16, 5, 9, 15] + +order = JunctionTrees.Order(graph, MetisJL_ND()) +@test order == [11, 17, 14, 13, 10, 12, 8, 6, 7, 5, 4, 3, 9, 2, 1, 16, 15] + +order = JunctionTrees.Order(graph, MCS()) +@test order == [2, 3, 4, 8, 1, 5, 6, 9, 7, 11, 13, 10, 14, 16, 12, 15, 17] + +order = JunctionTrees.Order(1:17) +parent = JunctionTrees.makeetree(graph, order) + +# Figure 4.2 +@test parent == [3, 3, 4, 5, 9, 9, 8, 9, 15, 11, 13, 13, 14, 16, 16, 17, 0] + +etree = JunctionTrees.Tree(17, parent) +indegrees, outdegrees = JunctionTrees.getdegrees(graph, order, etree) + +@test indegrees == [0, 0, 2, 3, 3, 0, 0, 1, 4, 0, 1, 0, 3, 4, 7, 7, 7] +@test outdegrees == [4, 2, 3, 2, 3, 2, 3, 2, 2, 4, 3, 4, 3, 2, 2, 1, 0] + +# Figure 4.3 +snd, sbt, q, a = JunctionTrees.makestree(etree, outdegrees, Node()) + +@test snd == [ + [1], + [2], + [3], + [4], + [5], + [6], + [7], + [8], + [9], + [10], + [11], + [12], + [13], + [14], + [15], + [16], + [17]] + +@test sbt == 1:17 +@test q == [3, 3, 4, 5, 9, 9, 8, 9, 15, 11, 13, 13, 14, 16, 16, 17, 0] +@test a == [3, 3, 4, 5, 9, 9, 8, 9, 15, 11, 13, 13, 14, 16, 16, 17, 0] + +# Figure 4.7 (left) +snd, sbt, q, a = JunctionTrees.makestree(etree, outdegrees, MaximalSupernode()) + +@test snd == [ + [1, 3, 4], + [2], + [5, 9], + [6], + [7, 8], + [10, 11], + [12, 13, 14, 16, 17], + [15] ] + +@test sbt == [1, 2, 1, 1, 3, 4, 5, 5, 3, 6, 6, 7, 7, 7, 8, 7, 7] +@test q == [3, 1, 8, 3, 3, 7, 0, 7] +@test a == [5, 3, 15, 9, 9, 13, 0, 16] + +# Figure 4.9 +snd, sbt, q, a = JunctionTrees.makestree(etree, outdegrees, FundamentalSupernode()) + +@test snd == [ + [1], + [2], + [3, 4], + [5], + [6], + [7, 8], + [9], + [10, 11], + [12], + [13, 14], + [15], + [16, 17] ] + +@test sbt == [1, 2, 3, 3, 4, 5, 6, 6, 7, 8, 8, 9, 10, 10, 11, 12, 12] +@test q == [3, 3, 4, 7, 7, 7, 11, 10, 10, 12, 12, 0] +@test a == [3, 3, 5, 9, 9, 9, 15, 13, 13, 16, 16, 0] + +# Figure 4.3 +jtree = JunctionTree(graph, order, Node()) + +@test getresidual.([jtree], getsubtree.([jtree], 1:17)) == [ + [1], + [2], + [3], + [4], + [5], + [6], + [7], + [8], + [9], + [10], + [11], + [12], + [13], + [14], + [15], + [16], + [17] ] + +@test getseperator.([jtree], getsubtree.([jtree], 1:17)) == [ + [3, 4, 5, 15], + [3, 4], + [4, 5, 15], + [5, 15], + [9, 15, 16], + [9, 16], + [8, 9, 15], + [9, 15], + [15, 16], + [11, 13, 14, 17], + [13, 14, 17], + [13, 14, 16, 17], + [14, 16, 17], + [16, 17], + [16, 17], + [17], + [] ] + +@test getlevel.([jtree], getsubtree.([jtree], 1:17)) == [ + 7, + 7, + 6, + 5, + 4, + 4, + 5, + 4, + 3, + 5, + 4, + 4, + 3, + 2, + 2, + 1, + 0 ] + +@test isdescendant(jtree, getsubtree(jtree, 5), getsubtree(jtree, 15)) +@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 5)) +@test !isdescendant(jtree, getsubtree(jtree, 10), getsubtree(jtree, 15)) +@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 10)) +@test !isdescendant(jtree, getsubtree(jtree, 1), getsubtree(jtree, 1)) +@test getwidth(jtree) == 4 + +# Figure 4.7 (left) +jtree = JunctionTree(graph, order, MaximalSupernode()) + +@test getresidual.([jtree], getsubtree.([jtree], 1:17)) == [ + [1, 3, 4], + [2], + [1, 3, 4], + [1, 3, 4], + [5, 9], + [6], + [7, 8], + [7, 8], + [5, 9], + [10, 11], + [10, 11], + [12, 13, 14, 16, 17], + [12, 13, 14, 16, 17], + [12, 13, 14, 16, 17], + [15], + [12, 13, 14, 16, 17], + [12, 13, 14, 16, 17]] + +@test getseperator.([jtree], getsubtree.([jtree], 1:17)) == [ + [5, 15], + [3, 4], + [5, 15], + [5, 15], + [15, 16], + [9, 16], + [9, 15], + [9, 15], + [15, 16], + [13, 14, 17], + [13, 14, 17], + [], + [], + [], + [16, 17], + [], + []] + +@test getlevel.([jtree], getsubtree.([jtree], 1:17)) == [ + 3, + 4, + 3, + 3, + 2, + 3, + 3, + 3, + 2, + 1, + 1, + 0, + 0, + 0, + 1, + 0, + 0 ] + +@test isdescendant(jtree, getsubtree(jtree, 5), getsubtree(jtree, 15)) +@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 5)) +@test !isdescendant(jtree, getsubtree(jtree, 10), getsubtree(jtree, 15)) +@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 10)) +@test !isdescendant(jtree, getsubtree(jtree, 1), getsubtree(jtree, 1)) +@test getwidth(jtree) == 4 + +# Figure 4.9 +jtree = JunctionTree(graph, order, FundamentalSupernode()) + +@test getresidual.([jtree], getsubtree.([jtree], 1:17)) == [ + [1], + [2], + [3, 4], + [3, 4], + [5], + [6], + [7, 8], + [7, 8], + [9], + [10, 11], + [10, 11], + [12], + [13, 14], + [13, 14], + [15], + [16, 17], + [16, 17]] + +@test getseperator.([jtree], getsubtree.([jtree], 1:17)) == [ + [3, 4, 5, 15], + [3, 4], + [5, 15], + [5, 15], + [9, 15, 16], + [9, 16], + [9, 15], + [9, 15], + [15, 16], + [13, 14, 17], + [13, 14, 17], + [13, 14, 16, 17], + [16, 17], + [16, 17], + [16, 17], + [], + []] + +@test isdescendant(jtree, getsubtree(jtree, 5), getsubtree(jtree, 15)) +@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 5)) +@test !isdescendant(jtree, getsubtree(jtree, 10), getsubtree(jtree, 15)) +@test !isdescendant(jtree, getsubtree(jtree, 15), getsubtree(jtree, 10)) +@test !isdescendant(jtree, getsubtree(jtree, 1), getsubtree(jtree, 1)) +@test getwidth(jtree) == 4 diff --git a/test/NestedUWDs.jl b/test/NestedUWDs.jl new file mode 100644 index 0000000..e156351 --- /dev/null +++ b/test/NestedUWDs.jl @@ -0,0 +1,122 @@ +using Catlab.RelationalPrograms +using Catlab.UndirectedWiringDiagrams +using LinearAlgebra +using StructuredDecompositions.NestedUWDs +using Test + + +# CategoricalTensorNetworks.jl +# https://github.com/AlgebraicJulia/CategoricalTensorNetworks.jl/ +function contract_tensor_network(d::UndirectedWiringDiagram, + tensors::AbstractVector{<:AbstractArray}) + @assert nboxes(d) == length(tensors) + juncs = [junction(d, ports(d, b)) for b in boxes(d)] + j_out = junction(d, ports(d, outer=true), outer=true) + contract_tensor_network(tensors, juncs, j_out) +end + + +function contract_tensor_network(tensors::AbstractVector{<:AbstractArray{T}}, + juncs::AbstractVector, j_out) where T + # Handle important binary case with specialized code. + if length(tensors) == 2 && length(juncs) == 2 + return contract_tensor_network(Tuple(tensors), Tuple(juncs), j_out) + end + + jsizes = Tuple(infer_junction_sizes(tensors, juncs, j_out)) + juncs, j_out = map(Tuple, juncs), Tuple(j_out) + C = zeros(T, Tuple(jsizes[j] for j in j_out)) + for index in CartesianIndices(jsizes) + x = one(T) + for (A, junc) in zip(tensors, juncs) + x *= A[(index[j] for j in junc)...] + end + C[(index[j] for j in j_out)...] += x + end + return C +end + + +function contract_tensor_network( # Binary case. + (A, B)::Tuple{<:AbstractArray{T},<:AbstractArray{T}}, + (jA, jB), j_out) where T + jsizes = Tuple(infer_junction_sizes((A, B), (jA, jB), j_out)) + jA, jB, j_out = Tuple(jA), Tuple(jB), Tuple(j_out) + C = zeros(T, Tuple(jsizes[j] for j in j_out)) + for index in CartesianIndices(jsizes) + C[(index[j] for j in j_out)...] += + A[(index[j] for j in jA)...] * B[(index[j] for j in jB)...] + end + return C +end + + +function infer_junction_sizes(tensors, juncs, j_out) + @assert length(tensors) == length(juncs) + njunc = maximum(Iterators.flatten((Iterators.flatten(juncs), j_out))) + jsizes = fill(-1, njunc) + for (A, junc) in zip(tensors, juncs) + for (i, j) in enumerate(junc) + if jsizes[j] == -1 + jsizes[j] = size(A, i) + else + @assert jsizes[j] == size(A, i) + end + end + end + @assert all(s >= 0 for s in jsizes) + jsizes +end + + +# out[v,z] = A[v,w] * B[w,x] * C[x,y] * D[y,z] +diagram = @relation (v, z) begin + A(v, w) + B(w, x) + C(x, y) + D(y, z) +end + +nuwd = NestedUWD(diagram) +A, B, C, D = map(randn, [(3, 4), (4, 5), (5, 6), (6, 7)]) +generators = Dict{Symbol, Array{Float64}}(:A => A, :B => B, :C => C, :D => D) +out = evalschedule(contract_tensor_network, nuwd, generators) +@test out ≈ A * B * C * D + +# out[] = A[w,x] * B[x,y] * C[y,z] * D[z,w] +diagram = @relation () begin + A(w, x) + B(x, y) + C(y, z) + D(z, w) +end + +nuwd = NestedUWD(diagram) +A, B, C, D = map(randn, [(10, 5), (5, 5), (5, 5), (5, 10)]) +generators = Dict{Symbol, Array{Float64}}(:A => A, :B => B, :C => C, :D => D) +out = evalschedule(contract_tensor_network, nuwd, generators) +@test out[] ≈ tr(A * B * C * D) + +# out[w,x,y,z] = A[w,x] * B[y,z] +diagram = @relation (w, x, y, z) begin + A(w, x) + B(y, z) +end + +nuwd = NestedUWD(diagram) +A, B = map(randn, [(3, 4), (5, 6)]) +generators = Dict{Symbol, Array{Float64}}(:A => A, :B => B) +out = evalschedule(contract_tensor_network, nuwd, generators) +@test out ≈ (reshape(A, (3, 4, 1, 1)) .* reshape(B, (1, 1, 5, 6))) + +# out[] = A[x,y] * B[x,y] +diagram = @relation () begin + A(x, y) + B(x, y) +end + +nuwd = NestedUWD(diagram) +A, B = map(randn, [(5, 5), (5, 5)]) +generators = Dict{Symbol, Array{Float64}}(:A => A, :B => B) +out = evalschedule(contract_tensor_network, nuwd, generators) +@test out[] ≈ dot(vec(A), vec(B)) diff --git a/test/Project.toml b/test/Project.toml index e0280cd..7d094ed 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,7 @@ [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" StructuredDecompositions = "ba32925c-6e4c-4640-bed9-b00febeea19a" diff --git a/test/runtests.jl b/test/runtests.jl index 5a68b88..6fdd6b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,4 +10,12 @@ end @testset "FunctorUtils" begin include("FunctorUtils.jl") -end \ No newline at end of file +end + +@testset "JunctionTrees" begin + include("JunctionTrees.jl") +end + +@testset "NestedUWDs" begin + include("NestedUWDs.jl") +end