Skip to content

Commit

Permalink
nuba optimize segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Apr 7, 2024
1 parent a414d8c commit 420c747
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 43 deletions.
2 changes: 2 additions & 0 deletions spf/dataset/spf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import yaml
from compress_pickle import load
from deepdiff import DeepDiff
from numba import njit
from torch.utils.data import Dataset

from spf.dataset.rover_idxs import ( # v3rx_column_names,
Expand Down Expand Up @@ -49,6 +50,7 @@
from spf.rf import ULADetector


@njit
def pi_norm(x):
return ((x + np.pi) % (2 * np.pi)) - np.pi

Expand Down
138 changes: 95 additions & 43 deletions spf/sdrpluto/sdr_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import adi
import matplotlib.pyplot as plt
import numpy as np
from numba import jit, njit
from tqdm import tqdm

from spf.dataset.spf_dataset import pi_norm
Expand Down Expand Up @@ -661,82 +662,130 @@ def setup_rxtx_and_phase_calibration(
return None, None


@njit
def circular_diff_to_mean(angles, mean):
a = np.abs(mean - angles) % (2 * np.pi)
b = 2 * np.pi - a
dists = np.empty(a.shape[0])
for i in range(a.shape[0]):
dists[i] = min(a[i], b[i])
return dists


# returns circular_stddev and trimmed cricular stddev
@njit
def circular_stddev(v, u, trim=50.0):
diff_from_mean = np.abs(
np.vstack([v - (u - 2 * np.pi), v - u, v - (u + 2 * np.pi)])
).min(axis=0)
_diff_from_mean = diff_from_mean[
diff_from_mean <= np.percentile(diff_from_mean, 100.0 - trim)
]
stddev = np.sqrt((diff_from_mean**2).sum() / (diff_from_mean.shape[0] - 1))
diff_from_mean = circular_diff_to_mean(angles=v, mean=u)

diff_from_mean_squared = diff_from_mean**2
stddev = np.sqrt(diff_from_mean_squared.sum() / (diff_from_mean.shape[0] - 1))

mask = diff_from_mean <= np.percentile(diff_from_mean, 100.0 - trim)
_diff_from_mean_squared = diff_from_mean_squared[mask]

trimmed_stddev = np.sqrt(
(_diff_from_mean**2).sum() / (_diff_from_mean.shape[0] - 1)
_diff_from_mean_squared.sum() / (_diff_from_mean_squared.shape[0] - 1)
)
return stddev, trimmed_stddev


# returns circular mean and trimmed circular mean
@njit
def circular_mean(angles, trim=50.0):
# n = angles.shape[0]
# assert angles.ndim == 1 or np.prod(a.shape[1:]) == 1
cm = np.arctan2(np.sin(angles).sum(), np.cos(angles).sum()) % (2 * np.pi)
dists = np.vstack([2 * np.pi - np.abs(cm - angles), np.abs(cm - angles)]).min(
axis=0
)
_angles = angles[dists < np.percentile(dists, 100.0 - trim)]
_cm = np.arctan2(np.sin(_angles).sum(), np.cos(_angles).sum()) % (2 * np.pi)
_sin_angles = np.sin(angles)
_cos_angles = np.cos(angles)
cm = np.arctan2(_sin_angles.sum(), _cos_angles.sum()) % (2 * np.pi)

##non JIT
# dists = np.vstack((2 * np.pi - np.abs(cm - angles), np.abs(cm - angles))).min(
# axis=0
# )

# JIT version
dists = circular_diff_to_mean(angles=angles, mean=cm)

mask = dists < np.percentile(dists, 100.0 - trim)
_cm = np.arctan2(_sin_angles[mask].sum(), _cos_angles[mask].sum()) % (2 * np.pi)
return pi_norm(cm), pi_norm(_cm)


def chunkify_array_start_end_idxs(v, window_size, stride):
assert (v.shape[0] - window_size) % stride == 0
steps = 1 + (v.shape[0] - window_size) // stride
for step in range(steps):
yield step * stride, step * stride + window_size
def segment_session(
z,
receiver,
session_idx,
window_size,
stride,
trim,
mean_diff_threshold,
max_stddev_threshold,
):
signal_matrix = z.receivers[receiver].signal_matrix[session_idx]
pd = get_phase_diff(signal_matrix)
return simple_segment(
pd,
window_size=window_size,
stride=stride,
trim=trim,
mean_diff_threshold=mean_diff_threshold, #
max_stddev_threshold=max_stddev_threshold, # just eyeballed this
)


@njit
def windowed_trimmed_circular_mean_and_stddev(v, window_size, stride, trim=50.0):
windows = []
for start_idx, end_idx in chunkify_array_start_end_idxs(
v, window_size=window_size, stride=stride
):
assert (v.shape[0] - window_size) % stride == 0
n_steps = 1 + (v.shape[0] - window_size) // stride
# for step in range(steps):
# yield
step_stats = np.zeros((n_steps, 2), dtype=np.float64)
step_idxs = np.zeros((n_steps, 2), dtype=np.int64)
steps = np.arange(n_steps)
# start_idx, end_idx
step_idxs[:, 0] = steps * stride
step_idxs[:, 1] = step_idxs[:, 0] + window_size
for step in range(n_steps):
start_idx, end_idx = step_idxs[step]
_v = v[start_idx:end_idx]
assert _v.shape[0] > 0
trimmed_cm = circular_mean(_v, trim=trim)[1]
windows.append(
{
"start_idx": start_idx,
"end_idx": end_idx,
"mean": trimmed_cm,
"stddev": circular_stddev(_v, trimmed_cm, trim=trim)[1],
}
)
return windows
step_stats[step, 0] = trimmed_cm
step_stats[step, 1] = circular_stddev(_v, trimmed_cm, trim=trim)[1]

return step_idxs, step_stats


def simple_segment(
v, window_size, stride, trim, mean_diff_threshold, max_stddev_threshold
):
valid_windows = []
for window in windowed_trimmed_circular_mean_and_stddev(
window_idxs, window_stats = windowed_trimmed_circular_mean_and_stddev(
v, window_size=window_size, stride=stride, trim=trim
):
)
for step in range(window_idxs.shape[0]):
start_idx, end_idx = window_idxs[step]
mean, stddev = window_stats[step]
# is this a valid region
if window["stddev"] < max_stddev_threshold:
if stddev < max_stddev_threshold:
if (
len(valid_windows) > 0 # if this is the first window
and valid_windows[-1]["end_idx"]
>= window["start_idx"] # check for overlap
and abs(valid_windows[-1]["mean"] - window["mean"])
and valid_windows[-1]["end_idx"] >= start_idx # check for overlap
and abs(valid_windows[-1]["mean"] - mean)
<= mean_diff_threshold # if not within tolerance
):
# append to previous window
valid_windows[-1]["end_idx"] = window["end_idx"]
valid_windows[-1]["mean"] = window["mean"] # recompute later
valid_windows[-1]["end_idx"] = end_idx
valid_windows[-1]["mean"] = mean # recompute later
else:
# add a new window
valid_windows.append(window.copy())
valid_windows.append(
{
"start_idx": int(start_idx),
"end_idx": int(end_idx),
"mean": mean,
"stddev": stddev,
}
)
# re-compute final stats as they are off
for window in valid_windows:
_v = v[window["start_idx"] : window["end_idx"]]
Expand All @@ -745,10 +794,13 @@ def simple_segment(
return valid_windows


@njit
def get_phase_diff(signal_matrix):
return pi_norm(np.angle(signal_matrix[0]) - np.angle(signal_matrix[1]))
diffs = (np.angle(signal_matrix[0]) - np.angle(signal_matrix[1])).astype(np.float64)
return pi_norm(diffs)


@njit
def get_avg_phase(signal_matrix, trim=0.0):
return circular_mean(get_phase_diff(signal_matrix=signal_matrix), trim=50.0)

Expand Down

0 comments on commit 420c747

Please sign in to comment.