Skip to content

Commit

Permalink
fix compat; add scripts to run inference and filters; fix nan issue w…
Browse files Browse the repository at this point in the history
…ith filters
  • Loading branch information
misko committed Nov 26, 2024
1 parent 4fc06c5 commit 101f847
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 16 deletions.
4 changes: 2 additions & 2 deletions spf/filters/ekf_dualradio_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ def trajectory(

# compute update = likelihood * prior
observation = self.observation(idx)

self.update(observation=observation)
if observation.isfinite().all():
self.update(observation=observation)

current_instance = {
"mu": self.x,
Expand Down
4 changes: 3 additions & 1 deletion spf/filters/ekf_single_radio_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def trajectory(

# compute update = likelihood * prior
observation = self.observation(idx)
self.update(observation=observation)

if observation.isfinite().all():
self.update(observation=observation)

current_instance = {
"mu": self.x,
Expand Down
14 changes: 8 additions & 6 deletions spf/filters/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,14 @@ def trajectory(
)
self.fix_particles()

self.update(z=self.observation(idx))

# resample if too few effective particles
if neff(self.weights) < N / 2:
indexes = torch.as_tensor(systematic_resample(self.weights.numpy()))
resample_from_index(self.particles, self.weights, indexes)
z = self.observation(idx)
if z.isfinite().all():
self.update(z)

# resample if too few effective particles
if neff(self.weights) < N / 2:
indexes = torch.as_tensor(systematic_resample(self.weights.numpy()))
resample_from_index(self.particles, self.weights, indexes)

mu, var = estimate(self.particles, self.weights)

Expand Down
11 changes: 11 additions & 0 deletions spf/scripts/inference_cache.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/bin/bash
if [ $# -ne 1 ]; then
echo $0 checkpoint_dir
exit
fi
dir=$1
precompute_cache=/mnt/4tb_ssd/precompute_cache_new
input_zarrs=/mnt/4tb_ssd/nosig_data/wallarrayv3_2024_*.zarr
root=/home/mouse9911/gits/spf/
inference_cache=/mnt/4tb_ssd/inference_cache/
python ${root}/spf/scripts/create_inference_cache.py --inference-cache ${inference_cache} --config-fn ${dir}/config.yml --checkpoint-fn ${dir}/best.pth --device cuda --datasets ${input_zarrs} --parallel 24 --precompute-cache ${precompute_cache}
13 changes: 13 additions & 0 deletions spf/scripts/run_filters.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash
val_file=/mnt/4tb_ssd/nosig_data/nov23_val.txt
root=/home/mouse9911/gits/spf/
precompute_cache=/mnt/4tb_ssd/precompute_cache_new
input_zarrs=/mnt/4tb_ssd/nosig_data/wallarrayv3_2024_*.zarr
root=/home/mouse9911/gits/spf
inference_cache=/mnt/4tb_ssd/inference_cache/

cat ${val_file} | while read x; do
python ${root}/spf/scripts/run_filters_on_data.py -d $x --nthetas 65 --device cpu --skip-qc \
--precompute-cache ${precompute_cache} --empirical-pkl-fn ${root}/empirical_dists/full.pkl \
--parallel 24 --work-dir ${root}/spf/run_on_filters_nov25 --config ${root}/spf/model_training_and_inference/models/ekf_and_pf_config.yml
done
5 changes: 0 additions & 5 deletions spf/scripts/run_filters_on_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,6 @@ def get_parser():
default=30,
required=False,
)
parser.add_argument(
"--output",
type=str,
required=True,
)
parser.add_argument(
"--work-dir",
type=str,
Expand Down
2 changes: 1 addition & 1 deletion spf/scripts/segment_zarr.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ fi
cat $1 | while read line; do
echo "processing $line"
fadvise -a willneeded $line
python segment_zarr.py --input-zarr $line --precompute-cache /mnt/4tb_ssd/precompute_cache_new/ --gpu -p 12
python segment_zarr.py --input-zarr $line --precompute-cache /mnt/4tb_ssd/precompute_cache_new/ --gpu -p 16
fadvise -a dontneed $line
done
6 changes: 5 additions & 1 deletion spf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,12 @@ def compare_and_copy(prefix, src, dst, skip_signal_matrix=False):
if isinstance(src, zarr.hierarchy.Group):
for key in src.keys():
if not skip_signal_matrix or key != "signal_matrix":
if key == "rx_heading" and key not in dst:
dst_key = "rx_heading_in_pis"
else:
dst_key = key
compare_and_copy(
prefix + "/" + key, src[key], dst[key], skip_signal_matrix
prefix + "/" + key, src[key], dst[dst_key], skip_signal_matrix
)
else:
if prefix == "/config":
Expand Down

0 comments on commit 101f847

Please sign in to comment.