diff --git a/spf/dataset/spf_dataset.py b/spf/dataset/spf_dataset.py index 408d41c..32149e2 100644 --- a/spf/dataset/spf_dataset.py +++ b/spf/dataset/spf_dataset.py @@ -5,6 +5,7 @@ import bisect import os import pickle +from contextlib import contextmanager from functools import cache from multiprocessing import Pool, cpu_count from typing import Dict, List @@ -19,28 +20,59 @@ from torch.utils.data import Dataset from spf.dataset.rover_idxs import ( # v3rx_column_names, - v3rx_avg_phase_diff_idxs, v3rx_beamformer_start_idx, v3rx_column_names, - v3rx_gain_idxs, v3rx_rssi_idxs, v3rx_rx_pos_idxs, v3rx_rx_theta_idx, - v3rx_time_idxs) + v3rx_avg_phase_diff_idxs, + v3rx_beamformer_start_idx, + v3rx_column_names, + v3rx_gain_idxs, + v3rx_rssi_idxs, + v3rx_rx_pos_idxs, + v3rx_rx_theta_idx, + v3rx_time_idxs, +) from spf.dataset.spf_generate import generate_session from spf.dataset.v5_data import v5rx_2xf64_keys, v5rx_f64_keys -from spf.dataset.wall_array_v1_idxs import (v1_beamformer_start_idx, - v1_column_names, v1_time_idx, - v1_tx_pos_idxs) +from spf.dataset.wall_array_v1_idxs import ( + v1_beamformer_start_idx, + v1_column_names, + v1_time_idx, + v1_tx_pos_idxs, +) from spf.dataset.wall_array_v2_idxs import ( # v3rx_column_names, - v2_avg_phase_diff_idxs, v2_beamformer_start_idx, v2_column_names, - v2_gain_idxs, v2_rssi_idxs, v2_rx_pos_idxs, v2_rx_theta_idx, v2_time_idx, - v2_tx_pos_idxs) -from spf.plot.image_utils import (detector_positions_to_theta_grid, - labels_to_source_images, radio_to_image) -from spf.rf import (ULADetector, phase_diff_to_theta, pi_norm, - precompute_steering_vectors, segment_session, - segment_session_star, speed_of_light, torch_get_phase_diff, - torch_pi_norm) + v2_avg_phase_diff_idxs, + v2_beamformer_start_idx, + v2_column_names, + v2_gain_idxs, + v2_rssi_idxs, + v2_rx_pos_idxs, + v2_rx_theta_idx, + v2_time_idx, + v2_tx_pos_idxs, +) +from spf.plot.image_utils import ( + detector_positions_to_theta_grid, + labels_to_source_images, + radio_to_image, +) +from spf.rf import ( + ULADetector, + phase_diff_to_theta, + pi_norm, + precompute_steering_vectors, + segment_session, + segment_session_star, + speed_of_light, + torch_get_phase_diff, + torch_pi_norm, +) from spf.sdrpluto.sdr_controller import rx_config_from_receiver_yaml -from spf.utils import (SEGMENTATION_VERSION, new_yarr_dataset, - rx_spacing_to_str, to_bin, zarr_open_from_lmdb_store, - zarr_shrink) +from spf.utils import ( + SEGMENTATION_VERSION, + new_yarr_dataset, + rx_spacing_to_str, + to_bin, + zarr_open_from_lmdb_store, + zarr_shrink, +) # from Stackoverflow @@ -424,6 +456,15 @@ def v5_segmentation_mask(session): return seg_mask[:, None] +@contextmanager +def v5spfdataset_manager(*args, **kwds): + ds = v5spfdataset(*args, **kwds) + try: + yield ds + finally: + ds.close() + + class v5spfdataset(Dataset): def __init__( self, @@ -713,6 +754,8 @@ def close(self): self.receiver_data = None # try and close segmentation self.segmentation = None + if self.precomputed_zarr is not None: + self.precomputed_zarr.store.close() self.precomputed_zarr = None def estimate_phi(self, data): diff --git a/spf/filters/ekf_dualradio_filter.py b/spf/filters/ekf_dualradio_filter.py index 2784c55..c1ac2d6 100644 --- a/spf/filters/ekf_dualradio_filter.py +++ b/spf/filters/ekf_dualradio_filter.py @@ -1,7 +1,6 @@ from functools import partial import numpy as np -import torch from filterpy.kalman import ExtendedKalmanFilter from matplotlib import pyplot as plt @@ -14,7 +13,7 @@ paired_hjacobian_phi_observation_from_theta_state, residual, ) -from spf.rf import pi_norm, torch_pi_norm_pi +from spf.rf import pi_norm class SPFPairedKalmanFilter(ExtendedKalmanFilter, SPFFilter): diff --git a/spf/filters/ekf_single_radio_filter.py b/spf/filters/ekf_single_radio_filter.py index 9e44b4f..88fc3fb 100644 --- a/spf/filters/ekf_single_radio_filter.py +++ b/spf/filters/ekf_single_radio_filter.py @@ -1,7 +1,6 @@ from functools import partial import numpy as np -import torch from filterpy.kalman import ExtendedKalmanFilter from matplotlib import pyplot as plt diff --git a/spf/filters/filters.py b/spf/filters/filters.py index e85aeca..86527e2 100644 --- a/spf/filters/filters.py +++ b/spf/filters/filters.py @@ -1,4 +1,3 @@ -import logging from functools import cache import numpy as np @@ -305,8 +304,6 @@ def trajectory( self.particles = create_gaussian_particles_xy( mean, std, N, generator=self.generator ) - print(f"Particles checksumx: {self.particles.abs().mean()} {mean} {std} {N}") - logging.info(f"Particles checksum: {self.particles.abs().mean()}") self.weights = torch.ones((N,), dtype=torch.float64) / N trajectory = [] for idx in range(len(self.ds)): diff --git a/spf/filters/particle_dualradioXY_filter.py b/spf/filters/particle_dualradioXY_filter.py index 7915412..36e4a7c 100644 --- a/spf/filters/particle_dualradioXY_filter.py +++ b/spf/filters/particle_dualradioXY_filter.py @@ -20,6 +20,10 @@ def __init__(self, ds): self.generator = torch.Generator() self.generator.manual_seed(0) + if not self.ds.temp_file: + self.all_observations = torch.vstack( + [self.ds.mean_phase["r0"], self.ds.mean_phase["r1"]] + ).T def our_state(self, idx): return torch.vstack( @@ -41,6 +45,8 @@ def ground_truth_thetas(self): return self.ds.craft_ground_truth_thetas def observation(self, idx): + if not self.ds.temp_file: + return self.all_observations[idx] return torch.concatenate( [ self.ds[idx][0]["mean_phase_segmentation"].reshape(1), diff --git a/spf/filters/particle_dualradio_filter.py b/spf/filters/particle_dualradio_filter.py index 83d3c6d..0e32bb1 100644 --- a/spf/filters/particle_dualradio_filter.py +++ b/spf/filters/particle_dualradio_filter.py @@ -1,5 +1,3 @@ -from functools import cache - import torch from matplotlib import pyplot as plt @@ -12,18 +10,48 @@ from spf.rf import torch_pi_norm_pi +@torch.jit.script +def pf_single_theta_dual_radio_update(weights, particles, z, empirical_dist, offsets): + z0 = theta_phi_to_p_vec( + torch_pi_norm_pi(particles[:, 0] - offsets[0]), + z[0], + empirical_dist[0], + ) + z1 = theta_phi_to_p_vec( + torch_pi_norm_pi(particles[:, 0] - offsets[1]), + z[1], + empirical_dist[1], + ) + weights = weights * z0 * z1 + 1.0e-30 # avoid round-off to zero + return weights / torch.sum(weights) # normalize + + class PFSingleThetaDualRadio(ParticleFilter): def __init__(self, ds): self.ds = ds - self.offsets = [ - ds.yaml_config["receivers"][0]["theta-in-pis"] * torch.pi, - ds.yaml_config["receivers"][1]["theta-in-pis"] * torch.pi, - ] + self.offsets = torch.tensor( + [ + ds.yaml_config["receivers"][0]["theta-in-pis"] * torch.pi, + ds.yaml_config["receivers"][1]["theta-in-pis"] * torch.pi, + ] + ) self.generator = torch.Generator() self.generator.manual_seed(0) + self.cached_empirical_dist = torch.vstack( + [ + self.ds.get_empirical_dist(0).T.unsqueeze(0), + self.ds.get_empirical_dist(1).T.unsqueeze(0), + ] + ) + if not self.ds.temp_file: + self.all_observations = torch.vstack( + [self.ds.mean_phase["r0"], self.ds.mean_phase["r1"]] + ).T def observation(self, idx): + if not self.ds.temp_file: + return self.all_observations[idx] return torch.concatenate( [ self.ds[idx][0]["mean_phase_segmentation"].reshape(1), @@ -32,10 +60,6 @@ def observation(self, idx): axis=0, ) - @cache - def cached_empirical_dist(self, rx_idx): - return self.ds.get_empirical_dist(rx_idx).T - def fix_particles(self): self.particles[:, 0] = torch_pi_norm_pi(self.particles[:, 0]) @@ -46,18 +70,13 @@ def predict(self, our_state, dt, noise_std): add_noise(self.particles, noise_std=noise_std, generator=self.generator) def update(self, z): - self.weights *= theta_phi_to_p_vec( - torch_pi_norm_pi(self.particles[:, 0] - self.offsets[0]), - z[0], - self.cached_empirical_dist(0), - ) - self.weights *= theta_phi_to_p_vec( - torch_pi_norm_pi(self.particles[:, 0] - self.offsets[1]), - z[1], - self.cached_empirical_dist(1), + self.weights = pf_single_theta_dual_radio_update( + self.weights, + self.particles, + z, + self.cached_empirical_dist, + self.offsets, ) - self.weights += 1.0e-30 # avoid round-off to zero - self.weights /= torch.sum(self.weights) # normalize def metrics(self, trajectory): return dual_radio_mse_theta_metrics( diff --git a/spf/model_training_and_inference/models/particle_filter.py b/spf/model_training_and_inference/models/particle_filter.py index 1321f3f..70f47a8 100644 --- a/spf/model_training_and_inference/models/particle_filter.py +++ b/spf/model_training_and_inference/models/particle_filter.py @@ -6,7 +6,7 @@ import torch import tqdm -from spf.dataset.spf_dataset import v5spfdataset +from spf.dataset.spf_dataset import v5spfdataset, v5spfdataset_manager from spf.filters.particle_dualradio_filter import PFSingleThetaDualRadio from spf.filters.particle_dualradioXY_filter import PFXYDualRadio from spf.filters.particle_single_radio_filter import PFSingleThetaSingleRadio @@ -14,19 +14,13 @@ torch.set_num_threads(1) -def run_single_theta_single_radio( - ds_fn, - precompute_fn, - empirical_pkl_fn, - theta_err=0.1, - theta_dot_err=0.001, - N=128, -): - ds = v5spfdataset( - ds_fn, +def run_jobs_with_one_dataset(kwargs): + results = [] + with v5spfdataset_manager( + kwargs["ds_fn"], nthetas=65, ignore_qc=True, - precompute_cache=precompute_fn, + precompute_cache=kwargs["precompute_cache"], paired=True, snapshots_per_session=1, readahead=True, @@ -39,9 +33,26 @@ def run_single_theta_single_radio( "simple_segmentations", ] ), - empirical_data_fn=empirical_pkl_fn, - ) + empirical_data_fn=kwargs["empirical_pkl_fn"], + ) as ds: + for fn, fn_kwargs in kwargs["jobs"]: + + fn_kwargs["ds"] = ds + new_results = fn(**fn_kwargs) + for result in new_results: + result["ds_fn"] = kwargs["ds_fn"] + results += new_results + return results + + +def run_single_theta_single_radio( + ds, + theta_err=0.1, + theta_dot_err=0.001, + N=128, +): metrics = [] + for rx_idx in [0, 1]: pf = PFSingleThetaSingleRadio(ds=ds, rx_idx=rx_idx) trajectory = pf.trajectory( @@ -54,7 +65,6 @@ def run_single_theta_single_radio( metrics.append( { "type": "single_theta_single_radio", - "ds_fn": ds_fn, "rx_idx": rx_idx, "theta_err": theta_err, "theta_dot_err": theta_dot_err, @@ -65,26 +75,8 @@ def run_single_theta_single_radio( return metrics -def run_single_theta_dual_radio( - ds_fn, precompute_fn, empirical_pkl_fn, theta_err=0.1, theta_dot_err=0.001, N=128 -): - ds = v5spfdataset( - ds_fn, - nthetas=65, - ignore_qc=True, - precompute_cache=precompute_fn, - paired=True, - snapshots_per_session=1, - skip_fields=set( - [ - "windowed_beamformer", - "all_windows_stats", - "downsampled_segmentation_mask", - "signal_matrix", - ] - ), - empirical_data_fn=empirical_pkl_fn, - ) +def run_single_theta_dual_radio(ds, theta_err=0.1, theta_dot_err=0.001, N=128): + pf = PFSingleThetaDualRadio(ds=ds) traj_paired = pf.trajectory( mean=torch.tensor([[0, 0]]), @@ -97,7 +89,6 @@ def run_single_theta_dual_radio( return [ { "type": "single_theta_dual_radio", - "ds_fn": ds_fn, "theta_err": theta_err, "theta_dot_err": theta_dot_err, "N": N, @@ -106,26 +97,8 @@ def run_single_theta_dual_radio( ] -def run_xy_dual_radio( - ds_fn, precompute_fn, empirical_pkl_fn, pos_err=15, vel_err=0.5, N=128 * 16 -): - ds = v5spfdataset( - ds_fn, - nthetas=65, - ignore_qc=True, - precompute_cache=precompute_fn, - paired=True, - snapshots_per_session=1, - skip_fields=set( - [ - "windowed_beamformer", - "all_windows_stats", - "downsampled_segmentation_mask", - "signal_matrix", - ] - ), - empirical_data_fn=empirical_pkl_fn, - ) +def run_xy_dual_radio(ds, pos_err=15, vel_err=0.5, N=128 * 16): + # dual radio dual pf = PFXYDualRadio(ds=ds) traj_paired = pf.trajectory( @@ -138,7 +111,6 @@ def run_xy_dual_radio( return [ { "type": "xy_dual_radio", - "ds_fn": ds_fn, "vel_err": vel_err, "pos_err": pos_err, "N": N, @@ -187,7 +159,7 @@ def get_parser(): required=True, ) parser.add_argument( - "--full-p-fn", + "--empirical-pkl-fn", type=str, required=True, ) @@ -197,6 +169,11 @@ def get_parser(): default=0, required=False, ) + parser.add_argument( + "--output", + type=str, + required=True, + ) parser.add_argument( "--debug", action=argparse.BooleanOptionalAction, @@ -209,66 +186,88 @@ def get_parser(): args = parser.parse_args() random.seed(args.seed) - jobs = [] + jobs_per_ds_fn = [] - for ds_fn in args.datasets: - for N in [128, 128 * 4, 128 * 8, 128 * 16]: - for theta_err in [0.1, 0.01, 0.001, 0.2]: - for theta_dot_err in [0.001, 0.0001, 0.01, 0.1]: - jobs.append( - ( - run_single_theta_single_radio, - { - "ds_fn": ds_fn, - "precompute_fn": args.precompute_cache, - "full_p_fn": args.full_p_fn, - "N": N, - "theta_err": theta_err, - "theta_dot_err": theta_dot_err, - }, - ) - ) - for theta_err in [0.1, 0.01, 0.001, 0.2]: - for theta_dot_err in [0.001, 0.0001, 0.01, 0.1]: - jobs.append( - ( - run_single_theta_dual_radio, - { - "ds_fn": ds_fn, - "precompute_fn": args.precompute_cache, - "full_p_fn": args.full_p_fn, - "N": N, - "theta_err": theta_err, - "theta_dot_err": theta_dot_err, - }, - ) - ) - for ds_fn in args.datasets: - for N in [128, 128 * 4, 128 * 8, 128 * 16, 128 * 32]: - for pos_err in [1000, 100, 50, 30, 15, 5, 0.5]: - for vel_err in [50, 5, 0.5, 0.05, 0.01, 0.001]: - jobs.append( - ( - run_xy_dual_radio, - { - "ds_fn": ds_fn, - "precompute_fn": args.precompute_cache, - "full_p_fn": args.full_p_fn, - "N": N, - "pos_err": pos_err, - "vel_err": vel_err, - }, - ) + # for N in [128, 128 * 4, 128 * 8, 128 * 16]: + # for theta_err in [0.1, 0.01, 0.001, 0.2]: + # for theta_dot_err in [0.001, 0.0001, 0.01, 0.1]: + # jobs_per_ds_fn.append( + # ( + # run_single_theta_single_radio, + # { + # "N": N, + # "theta_err": theta_err, + # "theta_dot_err": theta_dot_err, + # }, + # ) + # ) + # for theta_err in [0.1, 0.01, 0.001, 0.2]: + # for theta_dot_err in [0.001, 0.0001, 0.01, 0.1]: + # jobs_per_ds_fn.append( + # ( + # run_single_theta_dual_radio, + # { + # "N": N, + # "theta_err": theta_err, + # "theta_dot_err": theta_dot_err, + # }, + # ) + # ) + for N in [128, 128 * 4, 128 * 8, 128 * 16, 128 * 32]: + for pos_err in [1000, 100, 50, 30, 15, 5, 0.5]: + for vel_err in [50, 5, 0.5, 0.05, 0.01, 0.001]: + jobs_per_ds_fn.append( + ( + run_xy_dual_radio, + { + "N": N, + "pos_err": pos_err, + "vel_err": vel_err, + }, ) + ) + + random.shuffle(jobs_per_ds_fn) - random.shuffle(jobs) + # one job per dataset + jobs = [ + { + "ds_fn": ds_fn, + "precompute_cache": args.precompute_cache, + "empirical_pkl_fn": args.empirical_pkl_fn, + "jobs": jobs_per_ds_fn, + } + for ds_fn in args.datasets + ] + + jobs = [] + for ds_fn in args.datasets: + for job in jobs_per_ds_fn: + jobs.append( + { + "ds_fn": ds_fn, + "precompute_cache": args.precompute_cache, + "empirical_pkl_fn": args.empirical_pkl_fn, + "jobs": [job], + } + ) if args.debug: - results = list(tqdm.tqdm(map(runner, jobs), total=len(jobs))) + results = list( + tqdm.tqdm( + map(run_jobs_with_one_dataset, jobs), + total=len(jobs), + ) + ) else: with Pool(20) as pool: # cpu_count()) # cpu_count() // 4) - results = list(tqdm.tqdm(pool.imap(runner, jobs), total=len(jobs))) - pickle.dump(results, open("particle_filter_results2.pkl", "wb")) + results = list( + tqdm.tqdm( + pool.imap(run_jobs_with_one_dataset, jobs), + total=len(jobs), + ) + ) + pickle.dump(results, open(args.output, "wb")) # run_single_theta_single_radio() # run_single_theta_dual_radio( diff --git a/tests/test_particle_filter.py b/tests/test_particle_filter.py index a982b95..75867ba 100644 --- a/tests/test_particle_filter.py +++ b/tests/test_particle_filter.py @@ -1,13 +1,11 @@ -import pickle import random import tempfile -import pytest import torch from spf.dataset.fake_dataset import partial_dataset from spf.dataset.open_partial_ds import open_partial_dataset_and_check_some -from spf.dataset.spf_dataset import v5spfdataset +from spf.dataset.spf_dataset import v5spfdataset, v5spfdataset_manager from spf.filters.particle_dualradio_filter import plot_single_theta_dual_radio from spf.filters.particle_dualradioXY_filter import plot_xy_dual_radio from spf.filters.particle_single_radio_filter import plot_single_theta_single_radio @@ -20,7 +18,7 @@ def test_single_theta_single_radio(noise1_n128_obits2): dirname, empirical_pkl_fn, ds_fn = noise1_n128_obits2 - ds = v5spfdataset( + with v5spfdataset_manager( ds_fn, precompute_cache=dirname, nthetas=65, @@ -29,24 +27,22 @@ def test_single_theta_single_radio(noise1_n128_obits2): paired=True, ignore_qc=True, gpu=False, - ) - args = { - "ds_fn": ds_fn, - "precompute_fn": dirname, - "empirical_pkl_fn": empirical_pkl_fn, - "N": 1024 * 4, - "theta_err": 0.01, - "theta_dot_err": 0.01, - } - results = run_single_theta_single_radio(**args) - for result in results: - assert result["metrics"]["mse_single_radio_theta"] < 0.05 - plot_single_theta_single_radio(ds) + ) as ds: + args = { + "ds": ds, + "N": 1024 * 4, + "theta_err": 0.01, + "theta_dot_err": 0.01, + } + results = run_single_theta_single_radio(**args) + for result in results: + assert result["metrics"]["mse_single_radio_theta"] < 0.05 + plot_single_theta_single_radio(ds) def test_single_theta_dual_radio(noise1_n128_obits2): dirname, empirical_pkl_fn, ds_fn = noise1_n128_obits2 - ds = v5spfdataset( + with v5spfdataset_manager( ds_fn, precompute_cache=dirname, nthetas=65, @@ -55,23 +51,21 @@ def test_single_theta_dual_radio(noise1_n128_obits2): paired=True, ignore_qc=True, gpu=False, - ) - args = { - "ds_fn": ds_fn, - "precompute_fn": dirname, - "empirical_pkl_fn": empirical_pkl_fn, - "N": 1024 * 4, - "theta_err": 0.01, - "theta_dot_err": 0.01, - } - result = run_single_theta_dual_radio(**args) - assert result[0]["metrics"]["mse_craft_theta"] < 0.15 - plot_single_theta_dual_radio(ds) + ) as ds: + args = { + "ds": ds, + "N": 1024 * 4, + "theta_err": 0.01, + "theta_dot_err": 0.01, + } + result = run_single_theta_dual_radio(**args) + assert result[0]["metrics"]["mse_craft_theta"] < 0.15 + plot_single_theta_dual_radio(ds) def test_XY_dual_radio(noise1_n128_obits2): dirname, empirical_pkl_fn, ds_fn = noise1_n128_obits2 - ds = v5spfdataset( + with v5spfdataset_manager( ds_fn, precompute_cache=dirname, nthetas=65, @@ -80,19 +74,17 @@ def test_XY_dual_radio(noise1_n128_obits2): paired=True, ignore_qc=True, gpu=False, - ) - args = { - "ds_fn": ds_fn, - "precompute_fn": dirname, - "empirical_pkl_fn": empirical_pkl_fn, - "N": 1024 * 4, - "pos_err": 50, - "vel_err": 0.1, - } + ) as ds: + args = { + "ds": ds, + "N": 1024 * 4, + "pos_err": 50, + "vel_err": 0.1, + } - result = run_xy_dual_radio(**args) - assert result[0]["metrics"]["mse_craft_theta"] < 0.25 - plot_xy_dual_radio(ds) + result = run_xy_dual_radio(**args) + assert result[0]["metrics"]["mse_craft_theta"] < 0.25 + plot_xy_dual_radio(ds) def test_partial(noise1_n128_obits2): @@ -135,7 +127,7 @@ def test_partial(noise1_n128_obits2): def test_partial_script(noise1_n128_obits2): - dirname, empirical_pkl_fn, ds_fn = noise1_n128_obits2 + _, _, ds_fn = noise1_n128_obits2 with tempfile.TemporaryDirectory() as tmpdirname: ds_fn_out = f"{tmpdirname}/partial" for partial_n in [10, 20]: