Skip to content

Commit

Permalink
Add threshold and percent bad frames to ZDriftMetrics definition
Browse files Browse the repository at this point in the history
  • Loading branch information
kushalbakshi committed Dec 6, 2024
1 parent 569e0cd commit 44ec968
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions element_calcium_imaging/imaging_no_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,20 +298,24 @@ class ZDriftMetrics(dj.Computed):
ProcessingTask (foreign key): Primary key from ProcessingTask.
ZDriftParamSet (foreign key): Primary key from ZDriftParamSet.
z_drift (longblob): Amount of drift in microns per frame in Z direction.
bad_frames_threshold (int): Drift threshold in microns where frames are excluded from registration.
percent_bad_frames (float): Percentage of frames with z-drift exceeding the threshold.
"""

definition = """
-> ProcessingTask
---
bad_frames=NULL: longblob # `True` if any value in z_drift > threshold from drift_params.
z_drift: longblob # Amount of drift in microns per frame in Z direction.
bad_frames_threshold: int # Drift threshold in microns where frames are excluded from registration.
percent_bad_frames: float # Percentage of frames with z-drift exceeding the threshold.
"""

_default_params = {
"pad_length": 5,
"slice_interval": 1,
"num_scans": 5,
"bad_frames_threshold": 3
"bad_frames_threshold": 3,
}

def make(self, key):
Expand All @@ -323,9 +327,7 @@ def _make_taper(size, width):
return np.convolve(m, k, mode="full") / k.sum()

nchannels = (scan.ScanInfo & key).fetch1("nchannels")
params = (ProcessingTask * ProcessingParamSet & key).fetch1(
"params"
)
params = (ProcessingTask * ProcessingParamSet & key).fetch1("params")
drift_params = params.get("ZDRIFT_PARAMS", self._default_params)

# use the same channel specified in ProcessingParamSet for this task
Expand Down Expand Up @@ -433,10 +435,18 @@ def _make_taper(size, width):
"slice_interval"
]

bad_frames_idx = np.where(np.abs(drift) >= drift_params["bad_frames_threshold"])[0]
bad_frames_idx = np.where(
np.abs(drift) >= drift_params["bad_frames_threshold"]
)[0]

self.insert1(
dict(**key, bad_frames=bad_frames_idx, z_drift=drift),
dict(
**key,
z_drift=drift,
bad_frames=bad_frames_idx,
bad_frames_threshold=drift_params["bad_frames_threshold"],
percent_bad_frames=len(bad_frames_idx) / len(drift) * 100,
),
)


Expand Down Expand Up @@ -550,9 +560,7 @@ def make(self, key):
"processing_method"
)

params = (ProcessingTask * ProcessingParamSet & key).fetch1(
"params"
)
params = (ProcessingTask * ProcessingParamSet & key).fetch1("params")
params.pop("ZDRIFT_PARAMS", None)

if method == "suite2p":
Expand Down Expand Up @@ -678,7 +686,8 @@ def make(self, key):
}
for f in pathlib.Path(output_dir).rglob("*")
if f.is_file()
], ignore_extra_fields=True
],
ignore_extra_fields=True,
)


Expand Down

0 comments on commit 44ec968

Please sign in to comment.