From 867a7d2ce6c58ef62681d09bedc5bb0a3d21d962 Mon Sep 17 00:00:00 2001 From: Daniel Olsen Date: Wed, 13 Apr 2022 09:13:12 -0700 Subject: [PATCH] refactor: use networkx traversal instead of homebrew --- prereise/gather/griddata/hifld/data_access/load.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/prereise/gather/griddata/hifld/data_access/load.py b/prereise/gather/griddata/hifld/data_access/load.py index 6e3b2e454..aacb3fc03 100644 --- a/prereise/gather/griddata/hifld/data_access/load.py +++ b/prereise/gather/griddata/hifld/data_access/load.py @@ -261,26 +261,18 @@ def _generate_linear_spanning_tree(segments): # The minimum spanning tree wasn't good, go to the next iteration pass - # Identify an endpoint of the minimum spanning tree + # Identify an endpoint of the linear minimum spanning tree node_degrees = dict(nx.degree(tree)) endpoint = min(k for k, v in node_degrees.items() if v == 1) - last_endpoint = None # Traverse the tree, adding segments in order joined_segments = [] - while True: - next_endpoints = list(set(tree[endpoint]) - {last_endpoint}) - assert len(next_endpoints) == 1 - next_endpoint = next_endpoints[0] - (seg1, which1), (seg2, _) = endpoint, next_endpoints[0] + traversal = nx.algorithms.traversal.dfs_edges(tree, source=endpoint) + for (seg1, which1), (seg2, _) in traversal: if seg1 == seg2: if which1 == "start": joined_segments += linear_segments[seg1] else: joined_segments += linear_segments[seg1][::-1] - if node_degrees[next_endpoint] == 1: - break - last_endpoint = endpoint - endpoint = next_endpoint return joined_segments