Skip to content

Commit

Permalink
Add metric evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed May 6, 2024
1 parent 68e5748 commit 51352f1
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 38 deletions.
53 changes: 42 additions & 11 deletions scripts/get_tracking_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,42 @@

def load_tracking_segmentation(experiment):
result_dir = os.path.join(ROOT, "results")
if experiment == "vit_l":
seg_path = os.path.join(result_dir, "vit_l.tif")
elif experiment == "vit_l_lm":
seg_path = os.path.join(result_dir, "vit_l_lm.tif")
elif experiment == "vit_l_specialist":
seg_path = os.path.join(result_dir, "vit_l_lm_specialist.tif")
elif experiment == "trackmate_stardist":
seg_path = os.path.join(result_dir, "trackmate_stardist", "every_3rd_fr_result.tif")
else:
raise ValueError(experiment)

return imageio.imread(seg_path)
if experiment.startswith("vit"):
if experiment == "vit_l":
seg_path = os.path.join(result_dir, "vit_l.tif")
seg = imageio.imread(seg_path)
# HACK
ignore_labels = [8, 44, 57, 102, 50]

elif experiment == "vit_l_lm":
seg_path = os.path.join(result_dir, "vit_l_lm.tif")
seg = imageio.imread(seg_path)
# HACK
ignore_labels = []

elif experiment == "vit_l_specialist":
seg_path = os.path.join(result_dir, "vit_l_lm_specialist.tif")
seg = imageio.imread(seg_path)
# HACK
ignore_labels = [88, 45, 30, 46]

# elif experiment == "trackmate_stardist":
# seg_path = os.path.join(result_dir, "trackmate_stardist", "every_3rd_fr_result.tif")
# seg = imageio.imread(seg_path)

else:
raise ValueError(experiment)

# HACK:
# we remove some labels as they have a weird lineage, is creating issues for creating the graph
# (e.g. frames where the object exists: 1, 2, 4, 5, 6)
seg[np.isin(seg, ignore_labels)] = 0

return seg

else: # return the result directory for stardist
return os.path.join(result_dir, "trackmate_stardist", "01_RES")


def check_tracking_results(raw, labels, curr_lineages, chosen_frames):
Expand Down Expand Up @@ -93,4 +117,11 @@ def get_tracking_data():
curr_frames = v["frames"]
v["frames"] = [frmaps[frval] for frval in curr_frames if frval in chosen_frames]

# HACK:
# we remove label with id 62 as it has a weird lineage, is creating issues for creating the graph
ignore_labels = [62, 87, 92, 99, 58]
labels[np.isin(labels, ignore_labels)] = 0
for _label in ignore_labels:
curr_lineages.pop(_label)

return raw, labels, curr_lineages, chosen_frames
108 changes: 81 additions & 27 deletions scripts/test_ctc_metric.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,56 @@
import os
import numpy as np
import pandas as pd

from deepcell_tracking.isbi_utils import trk_to_isbi

from traccuracy.loaders._ctc import _get_node_attributes
from traccuracy import run_metrics
from traccuracy._tracking_graph import TrackingGraph
from traccuracy.matchers import CTCMatcher, IOUMatcher
from traccuracy.metrics import CTCMetrics, DivisionMetrics
from traccuracy.loaders._ctc import _get_node_attributes, ctc_to_graph, _check_ctc, load_ctc_data

from get_tracking_results import get_tracking_data, load_tracking_segmentation


def mark_potential_split(frames, last_frame, idx):
if frames.max() == last_frame: # object is tracked until the last frame
split_frame = None # they can't split in this case
prev_parent_id = None
else: # object either goes out of frame or splits
split_frame = frames.max() # let's assume that it splits, we will know if it does or not
prev_parent_id = idx
return split_frame, prev_parent_id


def extract_df_from_segmentation(segmentation):
track_ids = np.unique(segmentation)[1:]
last_frame = segmentation.shape[0] - 1

all_tracks = []
splits = 0
for idx in track_ids:
prev_parent_id = None

for idx in track_ids:
frames = np.unique(np.where(segmentation == idx)[0])

if frames.min() == 0: # object starts at first frame
if frames.max() == last_frame: # object is tracked until the last frame
pid = 0
have_fam = None # they can't split in this case
else: # object either goes out of frame or splits
pid = 0
have_fam = frames.max() # let's assume that it splits, we will know if it does or not
pid = 0
split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx)

else:
if have_fam is not None: # takes the parent information from above
pid = have_fam
splits += 1

if splits > 2: # assumes every mother cell splits into two daughter cells
print("The mother cell has made enough daughter splits, hence this is a new object.")
splits = 0
# pid = 0 # this is the case where an objects appears at nth frame and has no parent id
if split_frame is not None: # takes the parent information from above
# have fam is the end frame of the potential parent, so our frame has to be the next frame
if split_frame + 1 == frames.min():
pid = prev_parent_id

# otherwise we just have some track that starts so it's not the child
else:
pid = 0
split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx)

else:
pid = 0 # assumes that it was an object that started at a random frame
split_frame, prev_parent_id = mark_potential_split(frames, last_frame, idx)

track_dict = {
"Cell_ID": idx,
Expand All @@ -44,30 +59,69 @@ def extract_df_from_segmentation(segmentation):
"Parent_ID": pid,
}

print(track_dict)
all_tracks.append(track_dict)
all_tracks.append(pd.DataFrame.from_dict([track_dict]))

breakpoint()
pred_tracks_df = pd.concat(all_tracks)
return pred_tracks_df


def evaluate_tracking(raw, labels, curr_lineages, chosen_frames, segmentation_method):
def evaluate_tracking(labels, curr_lineages, segmentation_method):
seg = load_tracking_segmentation(segmentation_method)

if os.path.isdir(seg): # for trackmate stardist
seg_T = load_ctc_data(
data_dir=seg,
track_path=os.path.join(seg, 'res_track.txt'),
name=f'DynamicNuclearNet-{segmentation_method}'
)

else: # for micro-sam
seg_nodes = _get_node_attributes(seg)
seg_df = extract_df_from_segmentation(seg)
seg_G = ctc_to_graph(seg_df, seg_nodes)
_check_ctc(seg_df, seg_nodes, seg)
seg_T = TrackingGraph(seg_G, segmentation=seg, name=f"DynamicNuclearNet-{segmentation_method}")

breakpoint()

# calcuates node attributes for each detection
gt_df = _get_node_attributes(labels)
seg_df = _get_node_attributes(seg)
gt_nodes = _get_node_attributes(labels)

# converts inputs to isbi-tracking format - the version expected as inputs in traccuracy
gt_df = trk_to_isbi(curr_lineages, path=None)

# creates graphs from ctc-type info (isbi-type? probably means the same thing)
gt_G = ctc_to_graph(gt_df, gt_nodes)

# OPTIONAL: This tests if inputs (images, dfs and node attributes) to create tracking graphs are as expected
_check_ctc(gt_df, gt_nodes, labels)

gt_T = TrackingGraph(gt_G, segmentation=labels, name="DynamicNuclearNet-GT")

ctc_results = run_metrics(
gt_data=gt_T,
pred_data=seg_T,
matcher=CTCMatcher(),
metrics=[CTCMetrics(), DivisionMetrics(max_frame_buffer=0)],
)
print(ctc_results)

# converts inputs to isbi-track format - the version expected as inputs in traccuracy
output = trk_to_isbi(curr_lineages, path=None)
breakpoint()

df = extract_df_from_segmentation(seg)
iou_results = run_metrics(
gt_data=gt_T,
pred_data=seg_T,
matcher=IOUMatcher(iou_threshold=0.1),
metrics=[DivisionMetrics(max_frame_buffer=0)],
)
print(iou_results)


def main():
raw, labels, curr_lineages, chosen_frames = get_tracking_data()

segmentation_method = "vit_l_specialist"
evaluate_tracking(raw, labels, curr_lineages, chosen_frames, segmentation_method)
evaluate_tracking(labels, curr_lineages, segmentation_method)


if __name__ == "__main__":
Expand Down

0 comments on commit 51352f1

Please sign in to comment.