Skip to content

Commit

Permalink
Update scripts to access stored gt tracks
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed May 7, 2024
1 parent 51352f1 commit f5bdc60
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 47 deletions.
26 changes: 1 addition & 25 deletions scripts/get_tracking_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,17 @@ def load_tracking_segmentation(experiment):
if experiment.startswith("vit"):
if experiment == "vit_l":
seg_path = os.path.join(result_dir, "vit_l.tif")
seg = imageio.imread(seg_path)
# HACK
ignore_labels = [8, 44, 57, 102, 50]

elif experiment == "vit_l_lm":
seg_path = os.path.join(result_dir, "vit_l_lm.tif")
seg = imageio.imread(seg_path)
# HACK
ignore_labels = []

elif experiment == "vit_l_specialist":
seg_path = os.path.join(result_dir, "vit_l_lm_specialist.tif")
seg = imageio.imread(seg_path)
# HACK
ignore_labels = [88, 45, 30, 46]

# elif experiment == "trackmate_stardist":
# seg_path = os.path.join(result_dir, "trackmate_stardist", "every_3rd_fr_result.tif")
# seg = imageio.imread(seg_path)

else:
raise ValueError(experiment)

# HACK:
# we remove some labels as they have a weird lineage, is creating issues for creating the graph
# (e.g. frames where the object exists: 1, 2, 4, 5, 6)
seg[np.isin(seg, ignore_labels)] = 0

seg = imageio.imread(seg_path)
return seg

else: # return the result directory for stardist
Expand Down Expand Up @@ -117,11 +100,4 @@ def get_tracking_data():
curr_frames = v["frames"]
v["frames"] = [frmaps[frval] for frval in curr_frames if frval in chosen_frames]

# HACK:
# we remove label with id 62 as it has a weird lineage, is creating issues for creating the graph
ignore_labels = [62, 87, 92, 99, 58]
labels[np.isin(labels, ignore_labels)] = 0
for _label in ignore_labels:
curr_lineages.pop(_label)

return raw, labels, curr_lineages, chosen_frames
108 changes: 108 additions & 0 deletions scripts/gt_tracks.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
,Cell_ID,Start,End,Parent_ID
0,1,0,23,0
1,2,0,23,0
2,3,0,23,0
3,4,0,23,0
4,5,0,20,0
5,6,0,5,0
6,7,0,23,0
7,8,0,23,0
8,9,0,23,0
9,10,0,4,0
10,11,0,22,0
11,12,0,23,0
12,13,0,3,0
13,14,0,23,0
14,15,0,3,0
15,16,0,23,0
16,17,0,23,0
17,18,0,23,0
18,19,0,3,0
19,20,0,23,0
20,21,0,23,0
21,22,0,23,0
22,23,0,22,0
23,24,0,3,0
24,25,0,23,0
25,26,0,23,0
26,27,0,4,0
27,28,0,5,0
28,29,0,23,0
29,30,0,23,0
30,31,0,20,0
31,32,0,11,0
32,33,0,23,0
33,34,0,23,0
34,35,0,7,0
35,36,0,23,0
36,37,0,18,0
37,38,0,23,0
38,39,0,23,0
39,40,0,18,0
40,41,0,1,0
41,42,0,23,0
42,43,0,2,0
43,44,0,23,0
44,45,0,23,0
45,46,0,23,0
46,47,0,23,0
47,48,0,2,0
48,49,0,13,0
49,50,0,13,0
50,51,0,23,0
51,52,0,23,0
52,53,0,3,0
53,54,0,23,0
54,55,0,23,0
55,56,0,1,0
56,57,0,3,0
57,58,0,12,0
58,59,0,23,0
59,60,0,0,0
60,61,1,23,60
61,62,1,23,0
62,63,2,23,41
63,64,2,23,41
64,65,2,23,56
65,66,2,23,56
66,67,3,23,48
67,68,3,23,48
68,69,3,23,43
69,70,3,23,43
70,109,4,23,0
71,71,4,23,13
72,72,4,23,13
73,73,4,23,57
74,74,4,23,24
75,75,4,23,24
76,76,4,23,53
77,77,4,23,53
78,78,4,4,0
79,79,5,23,10
80,80,5,23,10
81,81,5,21,27
82,82,5,9,27
83,83,6,23,6
84,84,6,23,6
85,85,6,23,28
86,86,6,23,28
87,87,7,13,0
88,88,7,21,0
89,89,8,18,0
90,90,8,23,35
91,91,8,23,35
92,92,12,14,0
93,93,12,23,32
94,94,12,23,32
95,95,13,15,58
96,96,13,23,58
97,97,14,14,50
98,98,14,23,50
99,99,14,23,49
100,100,14,23,49
101,101,18,22,0
102,102,19,23,0
103,103,23,23,23
104,104,23,23,23
105,105,23,23,11
106,106,23,23,11
46 changes: 24 additions & 22 deletions scripts/test_ctc_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@
import numpy as np
import pandas as pd

from deepcell_tracking.isbi_utils import trk_to_isbi

from traccuracy import run_metrics
from traccuracy.matchers import CTCMatcher
from traccuracy._tracking_graph import TrackingGraph
from traccuracy.matchers import CTCMatcher, IOUMatcher
from traccuracy.metrics import CTCMetrics, DivisionMetrics
from traccuracy.loaders._ctc import _get_node_attributes, ctc_to_graph, _check_ctc, load_ctc_data

from get_tracking_results import get_tracking_data, load_tracking_segmentation


def mark_potential_split(frames, last_frame, idx):
if frames.max() == last_frame: # object is tracked until the last frame
Expand Down Expand Up @@ -65,12 +61,10 @@ def extract_df_from_segmentation(segmentation):
return pred_tracks_df


def evaluate_tracking(labels, curr_lineages, segmentation_method):
seg = load_tracking_segmentation(segmentation_method)

def evaluate_tracking(raw, labels, seg, segmentation_method):
if os.path.isdir(seg): # for trackmate stardist
seg_T = load_ctc_data(
data_dir=seg,
data_dir=seg,
track_path=os.path.join(seg, 'res_track.txt'),
name=f'DynamicNuclearNet-{segmentation_method}'
)
Expand All @@ -84,11 +78,12 @@ def evaluate_tracking(labels, curr_lineages, segmentation_method):

breakpoint()

# calcuates node attributes for each detection
# calcuates node attributes for each detectionc
gt_nodes = _get_node_attributes(labels)

# converts inputs to isbi-tracking format - the version expected as inputs in traccuracy
gt_df = trk_to_isbi(curr_lineages, path=None)
# it's preconverted using "from deepcell_tracking.isbi_utils import trk_to_isbi"
gt_df = pd.read_csv("./gt_tracks.csv")

# creates graphs from ctc-type info (isbi-type? probably means the same thing)
gt_G = ctc_to_graph(gt_df, gt_nodes)
Expand All @@ -106,22 +101,29 @@ def evaluate_tracking(labels, curr_lineages, segmentation_method):
)
print(ctc_results)

breakpoint()

iou_results = run_metrics(
gt_data=gt_T,
pred_data=seg_T,
matcher=IOUMatcher(iou_threshold=0.1),
metrics=[DivisionMetrics(max_frame_buffer=0)],
)
print(iou_results)
def get_tracking_data(segmentation_method):
import h5py

with h5py.File("./tracking_micro_sam.h5", "r") as f:
raw = f["raw"][:]
labels = f["labels"][:]

if segmentation_method.startswith("vit"):
segmentation = f[f"segmentations/{segmentation_method}"][:]
else:
ROOT = "/scratch/projects/nim00007/sam/for_tracking"
result_dir = os.path.join(ROOT, "results")
segmentation = os.path.join(result_dir, "trackmate_stardist", "01_RES")

return raw, labels, segmentation


def main():
raw, labels, curr_lineages, chosen_frames = get_tracking_data()
segmentation_method = "trackmate_stardist"

segmentation_method = "vit_l_specialist"
evaluate_tracking(labels, curr_lineages, segmentation_method)
raw, labels, segmentation = get_tracking_data(segmentation_method)
evaluate_tracking(raw, labels, segmentation, segmentation_method)


if __name__ == "__main__":
Expand Down

0 comments on commit f5bdc60

Please sign in to comment.