Skip to content

Commit

Permalink
Improved performance of seperators function.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelsonric committed Nov 17, 2024
1 parent 41e6a4e commit 9fe79bc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 56 deletions.
60 changes: 7 additions & 53 deletions src/junction_trees/supernode_trees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,36 +52,6 @@ function SupernodeTree(etree::EliminationTree, stype::SupernodeType=DEFAULT_SUPE
end


# Construct an elimination graph.
function eliminationgraph(stree::SupernodeTree)
graph = deepcopy(stree.graph)
n = length(stree.tree)

for i in 1:n - 1
for u in supernode(stree, i)[1:end - 1]
v = u + 1

for w in outneighbors(graph, u)
if v < w
add_edge!(graph, v, w)
end
end
end

u = last(supernode(stree, i))
v = first(supernode(stree, parentindex(stree.tree, i)))

for w in outneighbors(graph, u)
if v < w
add_edge!(graph, v, w)
end
end
end

graph
end


# Compute the width of a supernodal elimination tree.
function width(stree::SupernodeTree)
maximum(stree.degree[stree.representative])
Expand All @@ -99,45 +69,29 @@ end
# Compute the (unsorted) seperators of every node in T.
function seperators(stree::SupernodeTree)
n = length(stree.tree)
seperator = Vector{Vector{Int}}(undef, n)
graph = eliminationgraph(stree)

for i in 1:n
clique = collect(outneighbors(graph, stree.representative[i]))
filter!(j -> stree.ancestor[i] <= j, clique)
seperator[i] = clique
end

seperator
end


# Compute the (unsorted) seperators of every node in T.
function _seperators(stree::SupernodeTree)
n = length(stree.tree)
seperator = Vector{SparseIntSet}(undef, n)
seperator = Vector{Set{Int}}(undef, n)

for i in 1:n - 1
seperator[i] = SparseIntSet(stree.ancestor[i])
end
seperator[i] = Set(stree.ancestor[i])

seperator[n] = SparseIntSet()

for i in 1:n - 1
for v in outneighbors(stree.graph, stree.representative[i])
if stree.ancestor[i] < v
push!(seperator[i], v)
end
end
end

for i in 1:n - 2
j = parentindex(stree.tree, i)

for v in seperator[i]
if stree.ancestor[j] < v
push!(seperator[j], v)
end
end

end

map(collect, seperator)
seperator[n] = Set()
seperator
end
6 changes: 3 additions & 3 deletions test/JunctionTrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ stree = SupernodeTree(graph, order, Node())
[17], # q
]

@test map(sort, seperators(stree)) == [
@test map(sort collect, seperators(stree)) == [
[2, 9, 10], # h i o
[9, 10], # i o
[9, 16], # i p
Expand Down Expand Up @@ -194,7 +194,7 @@ stree = SupernodeTree(graph, order, Maximal())
[13, 14, 15, 16, 17], # l m n p q
]

@test map(sort, seperators(stree)) == [
@test map(sort collect, seperators(stree)) == [
[9, 10], # i o
[9, 16], # i p
[6, 7], # c d
Expand Down Expand Up @@ -277,7 +277,7 @@ stree = SupernodeTree(graph, order, Fundamental())
[16, 17], # p q
]

@test map(sort, seperators(stree)) == [
@test map(sort collect, seperators(stree)) == [
[9, 10], # i o
[9, 16], # i p
[6, 7], # c d
Expand Down

0 comments on commit 9fe79bc

Please sign in to comment.