Commit

speed up observation() function by at least 10x for non-live datasets; update handling of particle filter cases
misko committed Nov 1, 2024
1 parent 89f02b2 commit 3988b29
Showing 8 changed files with 254 additions and 200 deletions.
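The speedup comes from precomputing the particle-filter observations once for datasets that are already fully written to disk (non-live) and then indexing into that tensor, instead of reassembling each observation through ds[idx] on every call. A distilled sketch of the pattern follows; the ds object is assumed to expose temp_file and mean_phase the way v5spfdataset does in the diff below, and the class name is illustrative.

import torch


class ObservationCache:
    """Sketch of the precompute-then-index pattern used by the filters in this commit."""

    def __init__(self, ds):
        self.ds = ds
        # Non-live datasets already hold the per-receiver mean phases as full
        # arrays, so the whole (N, 2) observation tensor is built a single time.
        if not ds.temp_file:
            self.all_observations = torch.vstack(
                [ds.mean_phase["r0"], ds.mean_phase["r1"]]
            ).T

    def observation(self, idx):
        if not self.ds.temp_file:
            return self.all_observations[idx]  # cheap row lookup
        # Live (still-recording) datasets fall back to per-index assembly,
        # which walks the dataset's __getitem__ machinery on every call.
        return torch.concatenate(
            [
                self.ds[idx][0]["mean_phase_segmentation"].reshape(1),
                self.ds[idx][1]["mean_phase_segmentation"].reshape(1),
            ],
            axis=0,
        )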
79 changes: 61 additions & 18 deletions spf/dataset/spf_dataset.py
@@ -5,6 +5,7 @@
import bisect
import os
import pickle
from contextlib import contextmanager
from functools import cache
from multiprocessing import Pool, cpu_count
from typing import Dict, List
@@ -19,28 +20,59 @@
from torch.utils.data import Dataset

from spf.dataset.rover_idxs import ( # v3rx_column_names,
v3rx_avg_phase_diff_idxs, v3rx_beamformer_start_idx, v3rx_column_names,
v3rx_gain_idxs, v3rx_rssi_idxs, v3rx_rx_pos_idxs, v3rx_rx_theta_idx,
v3rx_time_idxs)
v3rx_avg_phase_diff_idxs,
v3rx_beamformer_start_idx,
v3rx_column_names,
v3rx_gain_idxs,
v3rx_rssi_idxs,
v3rx_rx_pos_idxs,
v3rx_rx_theta_idx,
v3rx_time_idxs,
)
from spf.dataset.spf_generate import generate_session
from spf.dataset.v5_data import v5rx_2xf64_keys, v5rx_f64_keys
from spf.dataset.wall_array_v1_idxs import (v1_beamformer_start_idx,
v1_column_names, v1_time_idx,
v1_tx_pos_idxs)
from spf.dataset.wall_array_v1_idxs import (
v1_beamformer_start_idx,
v1_column_names,
v1_time_idx,
v1_tx_pos_idxs,
)
from spf.dataset.wall_array_v2_idxs import ( # v3rx_column_names,
v2_avg_phase_diff_idxs, v2_beamformer_start_idx, v2_column_names,
v2_gain_idxs, v2_rssi_idxs, v2_rx_pos_idxs, v2_rx_theta_idx, v2_time_idx,
v2_tx_pos_idxs)
from spf.plot.image_utils import (detector_positions_to_theta_grid,
labels_to_source_images, radio_to_image)
from spf.rf import (ULADetector, phase_diff_to_theta, pi_norm,
precompute_steering_vectors, segment_session,
segment_session_star, speed_of_light, torch_get_phase_diff,
torch_pi_norm)
v2_avg_phase_diff_idxs,
v2_beamformer_start_idx,
v2_column_names,
v2_gain_idxs,
v2_rssi_idxs,
v2_rx_pos_idxs,
v2_rx_theta_idx,
v2_time_idx,
v2_tx_pos_idxs,
)
from spf.plot.image_utils import (
detector_positions_to_theta_grid,
labels_to_source_images,
radio_to_image,
)
from spf.rf import (
ULADetector,
phase_diff_to_theta,
pi_norm,
precompute_steering_vectors,
segment_session,
segment_session_star,
speed_of_light,
torch_get_phase_diff,
torch_pi_norm,
)
from spf.sdrpluto.sdr_controller import rx_config_from_receiver_yaml
from spf.utils import (SEGMENTATION_VERSION, new_yarr_dataset,
rx_spacing_to_str, to_bin, zarr_open_from_lmdb_store,
zarr_shrink)
from spf.utils import (
SEGMENTATION_VERSION,
new_yarr_dataset,
rx_spacing_to_str,
to_bin,
zarr_open_from_lmdb_store,
zarr_shrink,
)


# from Stackoverflow
@@ -424,6 +456,15 @@ def v5_segmentation_mask(session):
return seg_mask[:, None]


@contextmanager
def v5spfdataset_manager(*args, **kwds):
ds = v5spfdataset(*args, **kwds)
try:
yield ds
finally:
ds.close()


class v5spfdataset(Dataset):
def __init__(
self,
Expand Down Expand Up @@ -713,6 +754,8 @@ def close(self):
self.receiver_data = None
# try and close segmentation
self.segmentation = None
if self.precomputed_zarr is not None:
self.precomputed_zarr.store.close()
self.precomputed_zarr = None

def estimate_phi(self, data):
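A usage sketch for the new v5spfdataset_manager context manager; the constructor arguments below are placeholders rather than the real v5spfdataset signature. On exit it calls ds.close(), which (per the change above) now also closes the store behind precomputed_zarr.

from spf.dataset.spf_dataset import v5spfdataset_manager

# Arguments are illustrative; pass whatever v5spfdataset itself expects
# (dataset path/prefix plus its usual keyword arguments).
with v5spfdataset_manager("/path/to/dataset.zarr", nthetas=65) as ds:
    first_sample = ds[0]
# ds.close() has run here, even if the block raised, so file handles held
# by the underlying stores are released.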
3 changes: 1 addition & 2 deletions spf/filters/ekf_dualradio_filter.py
@@ -1,7 +1,6 @@
from functools import partial

import numpy as np
import torch
from filterpy.kalman import ExtendedKalmanFilter
from matplotlib import pyplot as plt

@@ -14,7 +13,7 @@
paired_hjacobian_phi_observation_from_theta_state,
residual,
)
from spf.rf import pi_norm, torch_pi_norm_pi
from spf.rf import pi_norm


class SPFPairedKalmanFilter(ExtendedKalmanFilter, SPFFilter):
1 change: 0 additions & 1 deletion spf/filters/ekf_single_radio_filter.py
@@ -1,7 +1,6 @@
from functools import partial

import numpy as np
import torch
from filterpy.kalman import ExtendedKalmanFilter
from matplotlib import pyplot as plt

3 changes: 0 additions & 3 deletions spf/filters/filters.py
@@ -1,4 +1,3 @@
import logging
from functools import cache

import numpy as np
@@ -305,8 +304,6 @@ def trajectory(
self.particles = create_gaussian_particles_xy(
mean, std, N, generator=self.generator
)
print(f"Particles checksumx: {self.particles.abs().mean()} {mean} {std} {N}")
logging.info(f"Particles checksum: {self.particles.abs().mean()}")
self.weights = torch.ones((N,), dtype=torch.float64) / N
trajectory = []
for idx in range(len(self.ds)):
6 changes: 6 additions & 0 deletions spf/filters/particle_dualradioXY_filter.py
@@ -20,6 +20,10 @@ def __init__(self, ds):

self.generator = torch.Generator()
self.generator.manual_seed(0)
if not self.ds.temp_file:
self.all_observations = torch.vstack(
[self.ds.mean_phase["r0"], self.ds.mean_phase["r1"]]
).T

def our_state(self, idx):
return torch.vstack(
@@ -41,6 +45,8 @@ def ground_truth_thetas(self):
return self.ds.craft_ground_truth_thetas

def observation(self, idx):
if not self.ds.temp_file:
return self.all_observations[idx]
return torch.concatenate(
[
self.ds[idx][0]["mean_phase_segmentation"].reshape(1),
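A quick shape check for the new fast path (all values below are stand-ins, with N playing the dataset length): stacking the two per-receiver mean-phase vectors and transposing gives an (N, 2) tensor, so all_observations[idx] has the same two-element layout that the slower per-index branch concatenates.

import torch

N = 5
r0 = torch.arange(N, dtype=torch.float64)   # stand-in for ds.mean_phase["r0"]
r1 = -torch.arange(N, dtype=torch.float64)  # stand-in for ds.mean_phase["r1"]

all_observations = torch.vstack([r0, r1]).T  # shape (N, 2)
assert all_observations.shape == (N, 2)
# Row idx pairs the two receivers' mean phases, matching the 2-element
# tensor the per-index path builds with torch.concatenate.
assert torch.equal(
    all_observations[3], torch.tensor([3.0, -3.0], dtype=torch.float64)
)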
61 changes: 40 additions & 21 deletions spf/filters/particle_dualradio_filter.py
@@ -1,5 +1,3 @@
from functools import cache

import torch
from matplotlib import pyplot as plt

@@ -12,18 +10,48 @@
from spf.rf import torch_pi_norm_pi


@torch.jit.script
def pf_single_theta_dual_radio_update(weights, particles, z, empirical_dist, offsets):
z0 = theta_phi_to_p_vec(
torch_pi_norm_pi(particles[:, 0] - offsets[0]),
z[0],
empirical_dist[0],
)
z1 = theta_phi_to_p_vec(
torch_pi_norm_pi(particles[:, 0] - offsets[1]),
z[1],
empirical_dist[1],
)
weights = weights * z0 * z1 + 1.0e-30 # avoid round-off to zero
return weights / torch.sum(weights) # normalize


class PFSingleThetaDualRadio(ParticleFilter):
def __init__(self, ds):
self.ds = ds
self.offsets = [
ds.yaml_config["receivers"][0]["theta-in-pis"] * torch.pi,
ds.yaml_config["receivers"][1]["theta-in-pis"] * torch.pi,
]
self.offsets = torch.tensor(
[
ds.yaml_config["receivers"][0]["theta-in-pis"] * torch.pi,
ds.yaml_config["receivers"][1]["theta-in-pis"] * torch.pi,
]
)

self.generator = torch.Generator()
self.generator.manual_seed(0)
self.cached_empirical_dist = torch.vstack(
[
self.ds.get_empirical_dist(0).T.unsqueeze(0),
self.ds.get_empirical_dist(1).T.unsqueeze(0),
]
)
if not self.ds.temp_file:
self.all_observations = torch.vstack(
[self.ds.mean_phase["r0"], self.ds.mean_phase["r1"]]
).T

def observation(self, idx):
if not self.ds.temp_file:
return self.all_observations[idx]
return torch.concatenate(
[
self.ds[idx][0]["mean_phase_segmentation"].reshape(1),
@@ -32,10 +60,6 @@ def observation(self, idx):
axis=0,
)

@cache
def cached_empirical_dist(self, rx_idx):
return self.ds.get_empirical_dist(rx_idx).T

def fix_particles(self):
self.particles[:, 0] = torch_pi_norm_pi(self.particles[:, 0])

@@ -46,18 +70,13 @@ def predict(self, our_state, dt, noise_std):
add_noise(self.particles, noise_std=noise_std, generator=self.generator)

def update(self, z):
self.weights *= theta_phi_to_p_vec(
torch_pi_norm_pi(self.particles[:, 0] - self.offsets[0]),
z[0],
self.cached_empirical_dist(0),
)
self.weights *= theta_phi_to_p_vec(
torch_pi_norm_pi(self.particles[:, 0] - self.offsets[1]),
z[1],
self.cached_empirical_dist(1),
self.weights = pf_single_theta_dual_radio_update(
self.weights,
self.particles,
z,
self.cached_empirical_dist,
self.offsets,
)
self.weights += 1.0e-30 # avoid round-off to zero
self.weights /= torch.sum(self.weights) # normalize

def metrics(self, trajectory):
return dual_radio_mse_theta_metrics(
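The weight update is now a single function decorated with torch.jit.script, instead of two in-place multiplications followed by a separate normalization. A stand-in sketch of that fused structure, with precomputed per-particle likelihoods taking the place of the theta_phi_to_p_vec lookups into the empirical distributions:

import torch


@torch.jit.script
def fused_update(
    weights: torch.Tensor, lik0: torch.Tensor, lik1: torch.Tensor
) -> torch.Tensor:
    # Fold both radios' per-particle likelihoods into the weights in one pass,
    # guard against underflow, then renormalize (the same structure as
    # pf_single_theta_dual_radio_update above).
    weights = weights * lik0 * lik1 + 1.0e-30
    return weights / torch.sum(weights)


n_particles = 4
weights = torch.full((n_particles,), 1.0 / n_particles, dtype=torch.float64)
lik0 = torch.tensor([0.1, 0.4, 0.4, 0.1], dtype=torch.float64)  # p(z0 | particle)
lik1 = torch.tensor([0.2, 0.3, 0.3, 0.2], dtype=torch.float64)  # p(z1 | particle)
weights = fused_update(weights, lik0, lik1)
assert torch.isclose(weights.sum(), torch.tensor(1.0, dtype=torch.float64))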
