From f5bdc60ed61675aa3d69751e95f445602605fa5c Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 7 May 2024 15:42:40 +0200 Subject: [PATCH] Update scripts to access stored gt tracks --- scripts/get_tracking_results.py | 26 +------- scripts/gt_tracks.csv | 108 ++++++++++++++++++++++++++++++++ scripts/test_ctc_metric.py | 46 +++++++------- 3 files changed, 133 insertions(+), 47 deletions(-) create mode 100644 scripts/gt_tracks.csv diff --git a/scripts/get_tracking_results.py b/scripts/get_tracking_results.py index 17d11d2..7d64fa0 100644 --- a/scripts/get_tracking_results.py +++ b/scripts/get_tracking_results.py @@ -17,34 +17,17 @@ def load_tracking_segmentation(experiment): if experiment.startswith("vit"): if experiment == "vit_l": seg_path = os.path.join(result_dir, "vit_l.tif") - seg = imageio.imread(seg_path) - # HACK - ignore_labels = [8, 44, 57, 102, 50] - elif experiment == "vit_l_lm": seg_path = os.path.join(result_dir, "vit_l_lm.tif") - seg = imageio.imread(seg_path) - # HACK - ignore_labels = [] - elif experiment == "vit_l_specialist": seg_path = os.path.join(result_dir, "vit_l_lm_specialist.tif") - seg = imageio.imread(seg_path) - # HACK - ignore_labels = [88, 45, 30, 46] - # elif experiment == "trackmate_stardist": # seg_path = os.path.join(result_dir, "trackmate_stardist", "every_3rd_fr_result.tif") - # seg = imageio.imread(seg_path) else: raise ValueError(experiment) - # HACK: - # we remove some labels as they have a weird lineage, is creating issues for creating the graph - # (e.g. frames where the object exists: 1, 2, 4, 5, 6) - seg[np.isin(seg, ignore_labels)] = 0 - + seg = imageio.imread(seg_path) return seg else: # return the result directory for stardist @@ -117,11 +100,4 @@ def get_tracking_data(): curr_frames = v["frames"] v["frames"] = [frmaps[frval] for frval in curr_frames if frval in chosen_frames] - # HACK: - # we remove label with id 62 as it has a weird lineage, is creating issues for creating the graph - ignore_labels = [62, 87, 92, 99, 58] - labels[np.isin(labels, ignore_labels)] = 0 - for _label in ignore_labels: - curr_lineages.pop(_label) - return raw, labels, curr_lineages, chosen_frames diff --git a/scripts/gt_tracks.csv b/scripts/gt_tracks.csv new file mode 100644 index 0000000..1d0065c --- /dev/null +++ b/scripts/gt_tracks.csv @@ -0,0 +1,108 @@ +,Cell_ID,Start,End,Parent_ID +0,1,0,23,0 +1,2,0,23,0 +2,3,0,23,0 +3,4,0,23,0 +4,5,0,20,0 +5,6,0,5,0 +6,7,0,23,0 +7,8,0,23,0 +8,9,0,23,0 +9,10,0,4,0 +10,11,0,22,0 +11,12,0,23,0 +12,13,0,3,0 +13,14,0,23,0 +14,15,0,3,0 +15,16,0,23,0 +16,17,0,23,0 +17,18,0,23,0 +18,19,0,3,0 +19,20,0,23,0 +20,21,0,23,0 +21,22,0,23,0 +22,23,0,22,0 +23,24,0,3,0 +24,25,0,23,0 +25,26,0,23,0 +26,27,0,4,0 +27,28,0,5,0 +28,29,0,23,0 +29,30,0,23,0 +30,31,0,20,0 +31,32,0,11,0 +32,33,0,23,0 +33,34,0,23,0 +34,35,0,7,0 +35,36,0,23,0 +36,37,0,18,0 +37,38,0,23,0 +38,39,0,23,0 +39,40,0,18,0 +40,41,0,1,0 +41,42,0,23,0 +42,43,0,2,0 +43,44,0,23,0 +44,45,0,23,0 +45,46,0,23,0 +46,47,0,23,0 +47,48,0,2,0 +48,49,0,13,0 +49,50,0,13,0 +50,51,0,23,0 +51,52,0,23,0 +52,53,0,3,0 +53,54,0,23,0 +54,55,0,23,0 +55,56,0,1,0 +56,57,0,3,0 +57,58,0,12,0 +58,59,0,23,0 +59,60,0,0,0 +60,61,1,23,60 +61,62,1,23,0 +62,63,2,23,41 +63,64,2,23,41 +64,65,2,23,56 +65,66,2,23,56 +66,67,3,23,48 +67,68,3,23,48 +68,69,3,23,43 +69,70,3,23,43 +70,109,4,23,0 +71,71,4,23,13 +72,72,4,23,13 +73,73,4,23,57 +74,74,4,23,24 +75,75,4,23,24 +76,76,4,23,53 +77,77,4,23,53 +78,78,4,4,0 +79,79,5,23,10 +80,80,5,23,10 +81,81,5,21,27 +82,82,5,9,27 +83,83,6,23,6 +84,84,6,23,6 +85,85,6,23,28 +86,86,6,23,28 +87,87,7,13,0 +88,88,7,21,0 +89,89,8,18,0 +90,90,8,23,35 +91,91,8,23,35 +92,92,12,14,0 +93,93,12,23,32 +94,94,12,23,32 +95,95,13,15,58 +96,96,13,23,58 +97,97,14,14,50 +98,98,14,23,50 +99,99,14,23,49 +100,100,14,23,49 +101,101,18,22,0 +102,102,19,23,0 +103,103,23,23,23 +104,104,23,23,23 +105,105,23,23,11 +106,106,23,23,11 diff --git a/scripts/test_ctc_metric.py b/scripts/test_ctc_metric.py index dc06637..fc6e9cf 100644 --- a/scripts/test_ctc_metric.py +++ b/scripts/test_ctc_metric.py @@ -2,16 +2,12 @@ import numpy as np import pandas as pd -from deepcell_tracking.isbi_utils import trk_to_isbi - from traccuracy import run_metrics +from traccuracy.matchers import CTCMatcher from traccuracy._tracking_graph import TrackingGraph -from traccuracy.matchers import CTCMatcher, IOUMatcher from traccuracy.metrics import CTCMetrics, DivisionMetrics from traccuracy.loaders._ctc import _get_node_attributes, ctc_to_graph, _check_ctc, load_ctc_data -from get_tracking_results import get_tracking_data, load_tracking_segmentation - def mark_potential_split(frames, last_frame, idx): if frames.max() == last_frame: # object is tracked until the last frame @@ -65,12 +61,10 @@ def extract_df_from_segmentation(segmentation): return pred_tracks_df -def evaluate_tracking(labels, curr_lineages, segmentation_method): - seg = load_tracking_segmentation(segmentation_method) - +def evaluate_tracking(raw, labels, seg, segmentation_method): if os.path.isdir(seg): # for trackmate stardist seg_T = load_ctc_data( - data_dir=seg, + data_dir=seg, track_path=os.path.join(seg, 'res_track.txt'), name=f'DynamicNuclearNet-{segmentation_method}' ) @@ -84,11 +78,12 @@ def evaluate_tracking(labels, curr_lineages, segmentation_method): breakpoint() - # calcuates node attributes for each detection + # calcuates node attributes for each detectionc gt_nodes = _get_node_attributes(labels) # converts inputs to isbi-tracking format - the version expected as inputs in traccuracy - gt_df = trk_to_isbi(curr_lineages, path=None) + # it's preconverted using "from deepcell_tracking.isbi_utils import trk_to_isbi" + gt_df = pd.read_csv("./gt_tracks.csv") # creates graphs from ctc-type info (isbi-type? probably means the same thing) gt_G = ctc_to_graph(gt_df, gt_nodes) @@ -106,22 +101,29 @@ def evaluate_tracking(labels, curr_lineages, segmentation_method): ) print(ctc_results) - breakpoint() - iou_results = run_metrics( - gt_data=gt_T, - pred_data=seg_T, - matcher=IOUMatcher(iou_threshold=0.1), - metrics=[DivisionMetrics(max_frame_buffer=0)], - ) - print(iou_results) +def get_tracking_data(segmentation_method): + import h5py + + with h5py.File("./tracking_micro_sam.h5", "r") as f: + raw = f["raw"][:] + labels = f["labels"][:] + + if segmentation_method.startswith("vit"): + segmentation = f[f"segmentations/{segmentation_method}"][:] + else: + ROOT = "/scratch/projects/nim00007/sam/for_tracking" + result_dir = os.path.join(ROOT, "results") + segmentation = os.path.join(result_dir, "trackmate_stardist", "01_RES") + + return raw, labels, segmentation def main(): - raw, labels, curr_lineages, chosen_frames = get_tracking_data() + segmentation_method = "trackmate_stardist" - segmentation_method = "vit_l_specialist" - evaluate_tracking(labels, curr_lineages, segmentation_method) + raw, labels, segmentation = get_tracking_data(segmentation_method) + evaluate_tracking(raw, labels, segmentation, segmentation_method) if __name__ == "__main__":