diff --git a/prereise/gather/griddata/hifld/data_access/load.py b/prereise/gather/griddata/hifld/data_access/load.py index 6e3b2e454..aacb3fc03 100644 --- a/prereise/gather/griddata/hifld/data_access/load.py +++ b/prereise/gather/griddata/hifld/data_access/load.py @@ -261,26 +261,18 @@ def _generate_linear_spanning_tree(segments): # The minimum spanning tree wasn't good, go to the next iteration pass - # Identify an endpoint of the minimum spanning tree + # Identify an endpoint of the linear minimum spanning tree node_degrees = dict(nx.degree(tree)) endpoint = min(k for k, v in node_degrees.items() if v == 1) - last_endpoint = None # Traverse the tree, adding segments in order joined_segments = [] - while True: - next_endpoints = list(set(tree[endpoint]) - {last_endpoint}) - assert len(next_endpoints) == 1 - next_endpoint = next_endpoints[0] - (seg1, which1), (seg2, _) = endpoint, next_endpoints[0] + traversal = nx.algorithms.traversal.dfs_edges(tree, source=endpoint) + for (seg1, which1), (seg2, _) in traversal: if seg1 == seg2: if which1 == "start": joined_segments += linear_segments[seg1] else: joined_segments += linear_segments[seg1][::-1] - if node_degrees[next_endpoint] == 1: - break - last_endpoint = endpoint - endpoint = next_endpoint return joined_segments