Commit 4ab935b: updates

misko committed Dec 9, 2024
Parent: 10974a9
Showing 6 changed files with 243 additions and 87 deletions.
47 changes: 28 additions & 19 deletions spf/model_training_and_inference/models/ekf_and_pf_config.yml
@@ -1,60 +1,69 @@
 precompute_caches:
-  3.2: "/mnt/4tb_ssd/precompute_cache_new"
-  3.11: "/mnt/4tb_ssd/precompute_cache"
-  3.3: "/mnt/4tb_ssd/precompute_cache_3p3"
+  # 3.2: "/mnt/4tb_ssd/precompute_cache_new"
+  # 3.11: "/mnt/4tb_ssd/precompute_cache"
+  # 3.3: "/mnt/4tb_ssd/precompute_cache_3p3"
+  3.4: "/mnt/4tb_ssd/precompute_cache_3p4"

 runs:
   - run_EKF_single_theta_single_radio:
       phi_std: [20, 18, 16, 10.0, 5.0, 2.5, 1.0]
       p: [10.0, 5.0, 2.5, 1.0, 0.5, 0.1]
       noise_std: [0.1, 0.01, 0.001, 0.0001, 0.00001, 0.0005, 0.0002, 0.00002, 0.00005]
       dynamic_R: [0.0]
-      segmentation_version: [3.11, 3.2]
+      segmentation_version: [3.4]
   - run_EKF_single_theta_single_radio:
       phi_std: [0.0]
       p: [10.0, 5.0, 2.5, 1.0, 0.5, 0.1]
       noise_std: [0.1, 0.01, 0.001, 0.0001, 0.00001, 0.0005, 0.0002, 0.00002, 0.00005]
       dynamic_R: [1.0, 0.1]
-      segmentation_version: [3.11, 3.2]
+      segmentation_version: [3.4]
   - run_EKF_single_theta_dual_radio:
       phi_std: [20, 18, 16, 14, 12, 10.0, 8, 5.0, 2.5, 1.0]
       p: [10.0, 5.0, 2.5, 1.0, 0.5, 0.1]
       noise_std: [0.1, 0.01, 0.001, 0.0001, 0.00001, 0.0005, 0.0002, 0.000001, 0.00002, 0.00005, 0.000002, 0.000005]
       dynamic_R: [0.0]
-      segmentation_version: [3.11, 3.2]
+      segmentation_version: [3.4]
   - run_EKF_single_theta_dual_radio:
       phi_std: [0.0]
       p: [10.0, 5.0, 2.5, 1.0, 0.5, 0.1]
       noise_std: [0.1, 0.01, 0.001, 0.0001, 0.00001, 0.0005, 0.0002, 0.000001, 0.00002, 0.00005, 0.000002, 0.000005]
       dynamic_R: [1.0, 0.1]
-      segmentation_version: [3.11, 3.2]
+      segmentation_version: [3.4]
   - run_PF_single_theta_single_radio:
       N: [128, 128 * 4, 128 * 8, 128 * 16, 128 * 32, 128 * 64, 128 * 128, 128 * 256]
       theta_err: [0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.075, 0.1, 0.2]
       theta_dot_err: [0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.1, 0.2]
-      segmentation_version: [3.11, 3.2]
+      segmentation_version: [3.4]
   - run_PF_single_theta_dual_radio:
       N: [128, 128 * 4, 128 * 8, 128 * 16, 128 * 32, 128 * 64, 128 * 128, 128 * 256]
       theta_err: [0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.075, 0.1, 0.2]
       theta_dot_err: [0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.1, 0.2]
-      segmentation_version: [3.11, 3.2]
+      segmentation_version: [3.4]
   - run_PF_single_theta_single_radio_NN:
       checkpoint_fn_and_segmentation_version:
-        - checkpoint_fn: "/home/mouse9911/gits/spf/nov22_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05_beamformerOnly_withMag/best.pth"
-          segmentation_version: 3.2
-        - checkpoint_fn: "/home/mouse9911/gits/spf/nov26_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05/best.pth"
-          segmentation_version: 3.2
-      inference_cache: ["/mnt/4tb_ssd/inference_cache_nov29/"]
+        - checkpoint_fn: "/home/mouse9911/gits/spf/dec3_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05_rerun_3p4/best.pth"
+          segmentation_version: 3.4
+        - checkpoint_fn: "/home/mouse9911/gits/spf/dec3_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05_rerun_3p4_windows_windowedbeamformer_nophase_nobeam_noemp_normalized/best.pth"
+          segmentation_version: 3.4
+        - checkpoint_fn: "/home/mouse9911/gits/spf/dec3_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05_rerun_3p4_windows_windowedbeamformer_nophase_nobeam_noemp_normalized_big2/best.pth"
+          segmentation_version: 3.4
+        - checkpoint_fn: "/home/mouse9911/gits/spf/dec3_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05_rerun_3p4_windows_fix_rerun/best.pth"
+          segmentation_version: 3.4
+      inference_cache: ["/mnt/md2/cache/inference"]
       N: [128, 128 * 4, 128 * 8, 128 * 16, 128 * 32, 128 * 64, 128 * 128]
       theta_err: [0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.075, 0.1, 0.15, 0.2]
       theta_dot_err: [0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.1, 0.2]
   - run_PF_single_theta_dual_radio_NN:
       checkpoint_fn_and_segmentation_version:
-        - checkpoint_fn: "/home/mouse9911/gits/spf/nov22_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05_beamformerOnly_withMag/best.pth"
-          segmentation_version: 3.2
-        - checkpoint_fn: "/home/mouse9911/gits/spf/nov26_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05/best.pth"
-          segmentation_version: 3.2
-      inference_cache: ["/mnt/4tb_ssd/inference_cache_nov29/"]
+        - checkpoint_fn: "/home/mouse9911/gits/spf/dec3_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05_rerun_3p4/best.pth"
+          segmentation_version: 3.4
+        - checkpoint_fn: "/home/mouse9911/gits/spf/dec3_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05_rerun_3p4_windows_windowedbeamformer_nophase_nobeam_noemp_normalized/best.pth"
+          segmentation_version: 3.4
+        - checkpoint_fn: "/home/mouse9911/gits/spf/dec3_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05_rerun_3p4_windows_windowedbeamformer_nophase_nobeam_noemp_normalized_big2/best.pth"
+          segmentation_version: 3.4
+        - checkpoint_fn: "/home/mouse9911/gits/spf/dec3_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05_rerun_3p4_windows_fix_rerun/best.pth"
+          segmentation_version: 3.4
+      inference_cache: ["/mnt/md2/cache/inference"]
       N: [128, 128 * 4, 128 * 8, 128 * 16, 128 * 32, 128 * 64, 128 * 128, 128 * 256]
       theta_err: [0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.075, 0.1, 0.15, 0.2]
       theta_dot_err: [0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.05, 0.075, 0.09, 0.1, 0.12, 0.15, 0.2]
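
Note that the N values are arithmetic expressions ("128 * 4"), not integers, so a plain YAML load returns them as strings and the sweep runner has to evaluate them before use. A minimal sketch of safe evaluation under that assumption (this ast-based helper is illustrative, not the repository's actual loader):

import ast

def eval_grid_value(value):
    # YAML parses "128 * 4" as a string; reduce it to an int without eval().
    if isinstance(value, int):
        return value
    node = ast.parse(str(value), mode="eval").body
    if (
        isinstance(node, ast.BinOp)
        and isinstance(node.op, ast.Mult)
        and isinstance(node.left, ast.Constant)
        and isinstance(node.right, ast.Constant)
    ):
        return node.left.value * node.right.value
    raise ValueError(f"unsupported grid expression: {value!r}")

print(eval_grid_value("128 * 32"))  # 4096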
155 changes: 128 additions & 27 deletions spf/model_training_and_inference/models/single_point_networks.py
@@ -1,4 +1,5 @@
import logging
from functools import lru_cache

import torch
from torch import nn
@@ -52,7 +53,7 @@ def __init__(self, model_config, global_config):
             > 1
         )

-    def prepare_input(self, batch):
+    def prepare_input(self, batch, additional_inputs=[]):
         dropout_mask = (
             torch.rand((5, *batch["y_rad"].shape), device=batch["y_rad"].device)
             < self.input_dropout
@@ -97,17 +98,27 @@ def prepare_input(self, batch):
         if self.training:
             v[dropout_mask[4]] = 0
         inputs.append(v)
-        return torch.concatenate(inputs, axis=2)
+        return torch.concatenate(
+            inputs + additional_inputs,
+            dim=2,
+        )


 def sigmoid_dist(x):
     x = torch.nn.functional.sigmoid(x)
     return x / x.sum(dim=-1, keepdim=True)
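
sigmoid_dist squashes each logit independently through a sigmoid and then renormalizes, giving a valid distribution over the last axis without softmax's winner-take-all sharpening. A self-contained check (the 65-bin width mirrors the 1/65.0 constant used later in this diff; the batch size is arbitrary):

import torch

def sigmoid_dist(x):
    x = torch.nn.functional.sigmoid(x)
    return x / x.sum(dim=-1, keepdim=True)

logits = torch.randn(4, 65)               # (batch, nthetas)
dist = sigmoid_dist(logits)
print(dist.min() >= 0, dist.sum(dim=-1))  # non-negative, each row sums to 1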


+# Warn only once; lru_cache(1) memoizes this zero-argument call
+@lru_cache(1)
+def warn_ntheta():
+    logging.warning("output_ntheta is not specified, defaulting to global_config")
+
+
 def check_and_load_ntheta(model_config, global_config):
     if "output_ntheta" not in model_config:
-        logging.warning("output_ntheta is not specified, defaulting to global_config")
+        warn_ntheta()
         model_config["output_ntheta"] = global_config["nthetas"]
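
Because warn_ntheta takes no arguments, @lru_cache(1) memoizes its single call: the warning is logged the first time and every later call is served from the cache, which keeps the log quiet when many models are constructed. The pattern in isolation:

import logging
from functools import lru_cache

@lru_cache(1)
def warn_once():
    logging.warning("output_ntheta is not specified, defaulting to global_config")

for _ in range(100):
    warn_once()  # the warning is emitted exactly once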


@@ -142,6 +153,18 @@ def forward(self, x):
         return self.conv_net(x).mean(axis=2)


+class SkipConnect(nn.Module):
+    def __init__(self, internal_module):
+        super().__init__()
+        self.internal_module = internal_module
+
+    def forward(self, x):
+        out = self.internal_module(x)
+        if out.shape == x.shape:
+            return out + x
+        return out
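
SkipConnect is a shape-guarded residual wrapper: it adds the input back onto the wrapped module's output when the shapes agree, and falls back to a plain pass-through when they do not (for example around a stride-2 convolution that halves the sequence length). A self-contained demonstration, repeating the class above so it runs on its own:

import torch
from torch import nn

class SkipConnect(nn.Module):
    def __init__(self, internal_module):
        super().__init__()
        self.internal_module = internal_module

    def forward(self, x):
        out = self.internal_module(x)
        if out.shape == x.shape:
            return out + x  # residual path when shapes match
        return out          # e.g. after a strided conv

x = torch.randn(2, 8, 32)
same = SkipConnect(nn.Conv1d(8, 8, 3, stride=1, padding=1))
down = SkipConnect(nn.Conv1d(8, 8, 3, stride=2, padding=1))
print(same(x).shape, down(x).shape)  # [2, 8, 32] with skip; [2, 8, 16] without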


 class AllWindowsStatsNet(nn.Module):
     def __init__(
         self,
@@ -154,50 +177,126 @@ def __init__(
         norm_size=256,
         n_layers=1,
         windowed_beamformer=False,
+        normalize_windowed_beamformer=False,
+        nthetas=0,
+        act="relu",
     ):
         super().__init__()
+        if act == "relu":
+            self.act = nn.ReLU
+        elif act == "leaky":
+            self.act = nn.LeakyReLU
+        elif act == "selu":
+            self.act = nn.SELU
+        else:
+            raise ValueError("invalid activation")
         padding = k // 2
         self.outputs = outputs
         self.output_phi = output_phi
         self.dropout = dropout
         self.windowed_beamformer = windowed_beamformer
+        self.normalize_windowed_beamformer = normalize_windowed_beamformer
+        self.nthetas = nthetas
         assert not dropout or not norm, "currently cannot do norm if dropout > 0.0"
         input_channels = 3
         if self.windowed_beamformer:
             input_channels += self.nthetas
         layers = [
-            nn.Conv1d(input_channels, hidden_channels, k, stride=1, padding=padding),
-            # nn.LayerNorm(norm_size) if norm else nn.Identity(),
-            nn.ReLU(),
-            nn.Conv1d(hidden_channels, hidden_channels, k, stride=2, padding=padding),
-            nn.LayerNorm(norm_size // 2) if norm else nn.Identity(),
-            nn.ReLU(),
+            SkipConnect(
+                nn.Sequential(
+                    nn.Conv1d(
+                        input_channels, hidden_channels, k, stride=1, padding=padding
+                    ),
+                    (
+                        nn.LayerNorm([hidden_channels, norm_size])
+                        if norm
+                        else nn.Identity()
+                    ),
+                    self.act(),
+                )
+            ),
         ]
+        # added extra layer here
+        for idx in range(n_layers - 1):
+            layers += [
+                SkipConnect(
+                    nn.Sequential(
+                        nn.Conv1d(
+                            hidden_channels,
+                            hidden_channels,
+                            k,
+                            stride=1,
+                            padding=padding,
+                        ),
+                        (
+                            nn.LayerNorm([hidden_channels, norm_size])
+                            if norm
+                            else nn.Identity()
+                        ),
+                        self.act(),
+                    )
+                ),
+            ]
+        layers += [
+            SkipConnect(
+                nn.Sequential(
+                    nn.Conv1d(
+                        hidden_channels, hidden_channels, k, stride=2, padding=padding
+                    ),
+                    (
+                        nn.LayerNorm([hidden_channels, norm_size // 2])
+                        if norm
+                        else nn.Identity()
+                    ),
+                    self.act(),
+                )
+            ),
+        ]
         for idx in range(n_layers):
             layers += [
-                nn.Conv1d(
-                    hidden_channels, hidden_channels, k, stride=1, padding=padding
+                SkipConnect(
+                    nn.Sequential(
+                        nn.Conv1d(
+                            hidden_channels,
+                            hidden_channels,
+                            k,
+                            stride=1,
+                            padding=padding,
+                        ),
+                        (
+                            nn.LayerNorm([hidden_channels, norm_size // 2])
+                            if norm
+                            else nn.Identity()
+                        ),
+                        self.act(),
+                    )
                 ),
-                nn.LayerNorm(norm_size // 2) if norm else nn.Identity(),
-                nn.ReLU(),
             ]
         layers += [
-            nn.Conv1d(hidden_channels, hidden_channels, k, stride=2, padding=padding),
-            nn.LayerNorm(norm_size // 4) if norm else nn.Identity(),
-            nn.ReLU(),
+            SkipConnect(
+                nn.Sequential(
+                    nn.Conv1d(
+                        hidden_channels, hidden_channels, k, stride=2, padding=padding
+                    ),
+                    (
+                        nn.LayerNorm([hidden_channels, norm_size // 4])
+                        if norm
+                        else nn.Identity()
+                    ),
+                    self.act(),
+                )
+            ),
             nn.Conv1d(hidden_channels, outputs, k, stride=1, padding=padding),
         ]
         self.conv_net = torch.nn.Sequential(*layers)
         if output_phi:
             self.phi_network = torch.nn.Sequential(
                 nn.Linear(outputs, hidden_channels),
-                nn.ReLU(),
+                self.act(),
                 nn.Linear(hidden_channels, hidden_channels),
-                nn.ReLU(),
+                self.act(),
                 nn.Linear(hidden_channels, hidden_channels),
-                nn.ReLU(),
+                self.act(),
                 nn.Linear(hidden_channels, 1),
             )

@@ -213,7 +312,15 @@ def forward(self, batch):

         inputs = [all_windows_normalized_input]
         if self.windowed_beamformer:
-            inputs.append(batch["windowed_beamformer"].transpose(2, 3) / 500)
+            if self.normalize_windowed_beamformer:
+                inputs.append(
+                    torch.nn.functional.normalize(
+                        batch["windowed_beamformer"].transpose(2, 3), p=1, dim=2
+                    )
+                    - 1 / 65.0
+                )
+            else:
+                inputs.append(batch["windowed_beamformer"].transpose(2, 3) / 500)
         input = torch.concatenate(inputs, dim=2)

         size_batch, size_snapshots, channels, windows = input.shape
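
The new normalize_windowed_beamformer branch L1-normalizes each window's response across the theta axis and subtracts 1/65 (presumably the uniform mass over 65 theta bins, given the hard-coded constant), so a flat beamformer response maps to zero. The arithmetic in isolation, with shapes as illustrative assumptions:

import torch

nthetas, windows = 65, 256               # sizes implied by the constants in this diff
wb = torch.rand(2, 4, windows, nthetas)  # assumed (batch, snapshots, windows, nthetas)
x = torch.nn.functional.normalize(wb.transpose(2, 3), p=1, dim=2) - 1 / nthetas
print(x.sum(dim=2).abs().max())          # ~0: each window's theta profile is centered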
@@ -311,13 +418,7 @@ def forward(self, batch):
             )
         return_dict["single"] = torch.nn.functional.normalize(
             self.single_point_with_beamformer_ffnn(
-                torch.concatenate(
-                    [
-                        self.prepare_input.prepare_input(batch),
-                    ]
-                    + additional_inputs,
-                    dim=2,
-                )
+                self.prepare_input.prepare_input(batch, additional_inputs),
             ).abs(),
             dim=2,
             p=1,
@@ -91,7 +91,7 @@ def get_inference_on_ds_noexceptions(
             segmentation_version=segmentation_version,
         )
     except Exception as e:
-        logging.error(f"Failed to process {ds_fn} with {e}")
+        logging.error(f"Failed to process {ds_fn} with {str(e)}")

def get_inference_on_ds(
@@ -124,9 +124,11 @@ def get_inference_on_ds(
         precompute_cache=precompute_cache,
         segmentation_version=segmentation_version,
     )

     config_checksum = get_md5_of_file(config_fn)
     checkpoint_checksum = get_md5_of_file(checkpoint_fn)
+    logging.debug(
+        f"checkpoint_checksum {checkpoint_checksum} config_checksum {config_checksum}"
+    )
     ds_basename = os.path.basename(ds_fn)
     inference_cache_fn = f"{inference_cache}/{ds_basename}/{segmentation_version:0.3f}/{checkpoint_checksum}/{config_checksum}.npz"
     if os.path.exists(inference_cache_fn):
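
The inference cache key combines the dataset basename, the segmentation version, and md5 checksums of both the checkpoint and the config, so editing either file invalidates the cached result. get_md5_of_file itself is not shown in this diff; a typical chunked implementation would look like this (an assumption about the helper, not the repository's code):

import hashlib

def get_md5_of_file(fn, chunk_size=1 << 20):
    # Hash the file in 1 MiB chunks to avoid loading it fully into memory.
    h = hashlib.md5()
    with open(fn, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()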