Skip to content

Commit

Permalink
Reduce cases of topological sort
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 committed Sep 27, 2024
1 parent d408c26 commit 3cd624e
Showing 1 changed file with 18 additions and 38 deletions.
56 changes: 18 additions & 38 deletions src/graph_traversal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,59 +29,38 @@ edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) =
edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) =
:+

#XXX: This topological sort is O(n^2).
function topological_sort_edges(d::SummationDecapode)
visited_Var = falses(nparts(d, :Var))
visited_Var[start_nodes(d)] .= true
visited = Dict(:Op1 => falses(nparts(d, :Op1)),
:Op2 => falses(nparts(d, :Op2)), => falses(nparts(d, )))

# TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ
visited_1 = falses(nparts(d, :Op1))
visited_2 = falses(nparts(d, :Op2))
visited_Σ = falses(nparts(d, ))

# FIXME: this is a quadratic implementation of topological_sort inlined in here.
op_order = TraversalNode{Symbol}[]

for _ in 1:n_ops(d)
for op in parts(d, :Op1)
if !visited_1[op] && visited_Var[d[op, :src]]

visited_1[op] = true
visited_Var[d[op, :tgt]] = true

push!(op_order, TraversalNode(op, :Op1))
end
end

for op in parts(d, :Op2)
if !visited_2[op] && visited_Var[d[op, :proj1]] && visited_Var[d[op, :proj2]]
visited_2[op] = true
visited_Var[d[op, :res]] = true
push!(op_order, TraversalNode(op, :Op2))
end
function visit(op, op_type)
if !visited[op_type][op] && all(visited_Var[edge_inputs(d,op,Val(op_type))])
visited[op_type][op] = true
visited_Var[edge_output(d,op,Val(op_type))] = true
push!(op_order, TraversalNode(op, op_type))
end
end

for op in parts(d, )
args = subpart(d, incident(d, op, :summation), :summand)
if !visited_Σ[op] && all(visited_Var[args])
visited_Σ[op] = true
visited_Var[d[op, :sum]] = true
push!(op_order, TraversalNode(op, ))
end
end
for _ in 1:n_ops(d)
visit.(parts(d,:Op1), :Op1)
visit.(parts(d,:Op2), :Op2)
visit.(parts(d,), )
end

@assert length(op_order) == n_ops(d)

op_order

Check warning on line 56 in src/graph_traversal.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_traversal.jl#L56

Added line #L56 was not covered by tests
end

function n_ops(d::SummationDecapode)
return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, )
end
n_ops(d::SummationDecapode) =
nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, )

function start_nodes(d::SummationDecapode)
return vcat(infer_states(d), incident(d, :Literal, :type))
end
start_nodes(d::SummationDecapode) =
vcat(infer_states(d), incident(d, :Literal, :type))

function retrieve_name(d::SummationDecapode, tsr::TraversalNode)
@match tsr.name begin
Expand All @@ -91,3 +70,4 @@ function retrieve_name(d::SummationDecapode, tsr::TraversalNode)
_ => error("$(tsr.name) is a table without names")

Check warning on line 70 in src/graph_traversal.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_traversal.jl#L70

Added line #L70 was not covered by tests
end
end

0 comments on commit 3cd624e

Please sign in to comment.